diff options
Diffstat (limited to 'lib9p/protogen/__init__.py')
-rw-r--r-- | lib9p/protogen/__init__.py | 749 |
1 files changed, 2 insertions, 747 deletions
diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 37cf6f5..c2c6173 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -10,757 +10,12 @@ import typing import idl -from . import c9util, cutil, h, idlutil - -# This strives to be "general-purpose" in that it just acts on the -# *.9p inputs; but (unfortunately?) there are a few special-cases in -# this script, marked with "SPECIAL". - +from . import c, h # pylint: disable=unused-variable __all__ = ["main"] -# Generate .c ################################################################## - - -def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: - cutil.ifdef_init() - - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#include <stdbool.h> -#include <stddef.h> /* for size_t */ -#include <inttypes.h> /* for PRI* macros */ -#include <string.h> /* for memset() */ - -#include <libmisc/assert.h> - -#include <lib9p/9p.h> - -#include "internal.h" -""" - - # utilities ################################################################ - ret += """ -/* utilities ******************************************************************/ -""" - - def used(arg: str) -> str: - return arg - - def unused(arg: str) -> str: - return f"LM_UNUSED({arg})" - - id2typ: dict[int, idl.Message] = {} - for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: - id2typ[msg.msgid] = msg - - def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: - ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" - for ver in ["unknown", *sorted(versions)]: - if ver != "unknown": - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" - for n in range(*rng): - xmsg: idl.Message | None = id2typ.get(n, None) - if xmsg: - if ver == "unknown": # SPECIAL (initialization) - if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: - xmsg = None - else: - if ver not in xmsg.in_versions: - xmsg = None - if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" - ret += "\t},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - return ret - - for v in sorted(versions): - ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += ( - f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" - ) - ret += "#else\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" - ret += "#endif\n" - ret += "\n" - ret += "/**\n" - ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" - ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" - ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" - ret += " * several version checks together.\n" - ret += " */\n" - ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" - - # strings ################################################################## - ret += f""" -/* strings ********************************************************************/ - -const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ -""" - for ver in ["unknown", *sorted(versions)]: - if ver in versions: - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" - ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) - - # bitmasks ################################################################# - ret += """ -/* bitmasks *******************************************************************/ -""" - for typ in typs: - if not isinstance(typ, idl.Bitfield): - continue - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) - ret += ( - f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - ( - "1" - if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) - and ver in bit.in_versions - else "0" - ) - for bit in reversed(typ.bits) - ) - + ",\n" - ) - ret += cutil.ifdef_pop(1) - ret += "};\n" - ret += cutil.ifdef_pop(0) - - # validate_* ############################################################### - ret += """ -/* validate_* *****************************************************************/ - -LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { -\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) -\t\t/* If needed-net-size overflowed uint32_t, then -\t\t * there's no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\tif (ctx->net_offset > ctx->net_size) -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { -\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) -\t\t/* If needed-host-size overflowed size_t, then there's -\t\t * no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, - size_t cnt, - _validate_fn_t item_fn, size_t item_host_size) { -\tfor (size_t i = 0; i < cnt; i++) -\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) -\t\t\treturn true; -\treturn false; -} - -LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } -LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } -LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } -LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } -""" - - def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: - return bool( - member.max or member.val or any(m.cnt == member for m in typ.members) - ) - - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" - - match typ: - case idl.Number(): - ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" - case idl.Bitfield(): - ret += f"\t if (validate_{typ.static_size}(ctx))\n" - ret += "\t\treturn true;\n" - ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" - if typ.static_size == 1: - ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" - else: - ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += "\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' - ret += "\treturn false;\n" - case idl.Struct(): # and idl.Message() - if len(typ.members) == 0: - ret += "\treturn false;\n" - ret += "}\n" - continue - - # Pass 1 - declare value variables - for member in typ.members: - if should_save_value(typ, member): - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" - ret += cutil.ifdef_pop(1) - - # Pass 2 - declare offset variables - mark_offset: set[str] = set() - for member in typ.members: - for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): - if tok.symname[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" - mark_offset.add(tok.symname[1:]) - - # Pass 3 - main pass - ret += "\treturn false\n" - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += "\t || " - if member.in_versions != typ.in_versions: - ret += "( " + c9util.ver_cond(member.in_versions) + " && " - if member.cnt is not None: - if member.typ.static_size == 1: # SPECIAL (zerocopy) - ret += f"_validate_size_net(ctx, {member.cnt.membname})" - else: - ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" - if typ.typname == "s": # SPECIAL (string) - ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' - else: - if should_save_value(typ, member): - ret += "(" - if member.membname in mark_offset: - ret += f"({{ _{member.membname}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.typname}(ctx)" - if member.membname in mark_offset: - ret += "; })" - if should_save_value(typ, member): - nbytes = member.static_size - assert nbytes - if nbytes == 1: - ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" - else: - ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" - if member.in_versions != typ.in_versions: - ret += " )" - ret += "\n" - - # Pass 4 - validate ,max= and ,val= constraints - for member in typ.members: - - def lookup_sym(sym: str) -> str: - match sym: - case "end": - return "ctx->net_offset" - case _: - assert sym.startswith("&") - return f"_{sym[1:]}_offset" - - if member.max: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' - if member.val: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' - - ret += cutil.ifdef_pop(1) - ret += "\t ;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # unmarshal_* ############################################################## - ret += """ -/* unmarshal_* ****************************************************************/ - -LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { -\t*out = ctx->net_bytes[ctx->net_offset]; -\tctx->net_offset += 1; -} - -LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { -\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 2; -} - -LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { -\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 4; -} - -LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { -\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 8; -} -""" - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" - match typ: - case idl.Number(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Struct(): - ret += "\tmemset(out, 0, sizeof(*out));\n" - - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - if member.val: - ret += f"\tctx->net_offset += {member.static_size};\n" - continue - ret += "\t" - - prefix = "\t" - if member.in_versions != typ.in_versions: - ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " - prefix = "\t\t" - if member.cnt: - if member.in_versions != typ.in_versions: - ret += "{\n" - ret += prefix - if member.typ.static_size == 1: # SPECIAL (string, zerocopy) - ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" - ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" - else: - ret += f"out->{member.membname} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" - if member.in_versions != typ.in_versions: - ret += "\t}\n" - else: - ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" - ret += cutil.ifdef_pop(1) - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # marshal_* ################################################################ - ret += """ -/* marshal_* ******************************************************************/ - -""" - ret += cutil.macro( - "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" - "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" - "\t\tctx->net_iov_cnt++;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" - "\tctx->net_iov_cnt++;\n" - ) - ret += cutil.macro( - "#define MARSHAL_BYTES(ctx, data, len)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" - "\tctx->net_copied_size += len;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U8LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tctx->net_copied[ctx->net_copied_size] = val;\n" - "\tctx->net_copied_size += 1;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U16LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 2;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U32LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 4;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U64LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 8;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" - ) - - class OffsetExpr: - static: int - cond: dict[frozenset[str], "OffsetExpr"] - rep: list[tuple[idlutil.Path, "OffsetExpr"]] - - def __init__(self) -> None: - self.static = 0 - self.rep = [] - self.cond = {} - - def add(self, other: "OffsetExpr") -> None: - self.static += other.static - self.rep += other.rep - for k, v in other.cond.items(): - if k in self.cond: - self.cond[k].add(v) - else: - self.cond[k] = v - - def gen_c( - self, - dsttyp: str, - dstvar: str, - root: str, - indent_depth: int, - loop_depth: int, - ) -> str: - oneline: list[str] = [] - multiline = "" - if self.static: - oneline.append(str(self.static)) - for cnt, sub in self.rep: - if not sub.cond and not sub.rep: - if sub.static == 1: - oneline.append(cnt.c_str(root)) - else: - oneline.append(f"({cnt.c_str(root)})*{sub.static}") - continue - loopvar = chr(ord("i") + loop_depth) - multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" - multiline += sub.gen_c( - "", dstvar, root, indent_depth + 1, loop_depth + 1 - ) - multiline += f"{'\t'*indent_depth}}}\n" - for vers, sub in self.cond.items(): - multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) - multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" - multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) - multiline += f"{'\t'*indent_depth}}}\n" - multiline += cutil.ifdef_pop(indent_depth) - if dsttyp: - if not oneline: - oneline.append("0") - ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" - elif oneline: - ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" - ret += multiline - return ret - - type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] - - def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: - if not isinstance(typ, idl.Struct): - assert typ.static_size - ret = OffsetExpr() - ret.static = typ.static_size - return ret - - stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]] - - def pop_root() -> None: - assert False - - def pop_cond() -> None: - nonlocal stack - key = frozenset(stack[-1][0].elems[-1].in_versions) - if key in stack[-2][1].cond: - stack[-2][1].cond[key].add(stack[-1][1]) - else: - stack[-2][1].cond[key] = stack[-1][1] - stack = stack[:-1] - - def pop_rep() -> None: - nonlocal stack - member_path = stack[-1][0] - member = member_path.elems[-1] - assert member.cnt - cnt_path = member_path.parent().add(member.cnt) - stack[-2][1].rep.append((cnt_path, stack[-1][1])) - stack = stack[:-1] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: - nonlocal recurse - - ret = recurse(path) - if ret != idlutil.WalkCmd.KEEP_GOING: - return ret, None - - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - stack[-1][2]() - - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - stack.append((path, OffsetExpr(), pop_cond)) - if child.cnt: - stack.append((path, OffsetExpr(), pop_rep)) - if not isinstance(child.typ, idl.Struct): - assert child.typ.static_size - stack[-1][1].static += child.typ.static_size - return ret, pop - - stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)] - idlutil.walk(typ, handle) - return stack[0][1] - - def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: - return idlutil.WalkCmd.KEEP_GOING - - def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: - def ret(path: idlutil.Path) -> idlutil.WalkCmd: - if len(path.elems) == 1 and path.elems[0].membname == name: - return idlutil.WalkCmd.ABORT - return idlutil.WalkCmd.KEEP_GOING - - return ret - - for typ in typs: - if not ( - isinstance(typ, idl.Message) or typ.typname == "stat" - ): # SPECIAL (include stat) - continue - assert isinstance(typ, idl.Struct) - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" - - # Pass 1 - check size - max_size = max(typ.max_size(v) for v in typ.in_versions) - - if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint64_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" - else: - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint32_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" - if isinstance(typ, idl.Message): # SPECIAL (disable for stat) - ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' - ret += f'\t\t\t"{typ.typname}",\n' - ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' - ret += "\t\t\tctx->ctx->max_msg_size);\n" - ret += "\t\treturn true;\n" - ret += "\t}\n" - - # Pass 2 - write data - ifdef_depth = 1 - stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - ret += f"{'\t'*(len(stack)-1)}}}\n" - if stack[-1][1]: - ifdef_depth -= 1 - ret += cutil.ifdef_pop(ifdef_depth) - stack = stack[:-1] - - loopdepth = sum(1 for elem in path.elems if elem.cnt) - struct = path.elems[-1].typ if path.elems else path.root - if isinstance(struct, idl.Struct): - offsets: list[str] = [] - for member in struct.members: - if not member.val: - continue - for tok in member.val.tokens: - if not isinstance(tok, idl.ExprSym): - continue - if tok.symname == "end" or tok.symname.startswith("&"): - if tok.symname not in offsets: - offsets.append(tok.symname) - for name in offsets: - name_prefix = "offsetof_" + "".join( - m.membname + "_" for m in path.elems - ) - if name == "end": - if not path.elems: - nonlocal max_size - if max_size > cutil.UINT32_MAX: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" - else: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" - continue - recurse: OffsetExprRecursion = go_to_end - else: - assert name.startswith("&") - name = name[1:] - recurse = go_to_tok(name) - expr = get_offset_expr(struct, recurse) - expr_prefix = path.c_str("val->", loopdepth) - if not expr_prefix.endswith(">"): - expr_prefix += "." - ret += expr.gen_c( - "uint32_t", - name_prefix + name, - expr_prefix, - len(stack), - loopdepth, - ) - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - ret += cutil.ifdef_push( - ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) - ) - ifdef_depth += 1 - ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" - stack.append((path, True)) - if child.cnt: - cnt_path = path.parent().add(child.cnt) - if child.typ.static_size == 1: # SPECIAL (zerocopy) - if path.root.typname == "stat": # SPECIAL (stat) - ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - else: - ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - loopvar = chr(ord("i") + loopdepth - 1) - ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" - stack.append((path, False)) - if not isinstance(child.typ, idl.Struct): - if child.val: - - def lookup_sym(sym: str) -> str: - nonlocal path - if sym.startswith("&"): - sym = sym[1:] - return ( - "offsetof_" - + "".join(m.membname + "_" for m in path.elems[:-1]) - + sym - ) - - val = c9util.idl_expr(child.val, lookup_sym) - else: - val = path.c_str("val->") - if isinstance(child.typ, idl.Bitfield): - val += f" & {child.typ.typname}_masks[ctx->ctx->version]" - ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - - idlutil.walk(typ, handle) - del handle - del stack - del max_size - - ret += "\treturn false;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # function tables ########################################################## - ret += """ -/* function tables ************************************************************/ -""" - - ret += "\n" - ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" - rerror = next(typ for typ in typs if typ.typname == "Rerror") - ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) - for ver in sorted(versions): - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += cutil.macro( - f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" - f"\t\t.validate = validate_##typ,\n" - f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" - f"\t}}\n" - ) - ret += cutil.macro( - f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" - f"\t}}\n" - ) - ret += "\n" - ret += msg_table( - "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) - ) - - ret += f""" -LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ -\treturn validate_stat(ctx); -}} -LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ -\tunmarshal_stat(ctx, out); -}} -LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ -\treturn marshal_stat(ctx, val); -}} -""" - - ############################################################################ - return ret - - -# Main ######################################################################### - - def main() -> None: if typing.TYPE_CHECKING: @@ -799,4 +54,4 @@ def main() -> None: ) as fh: fh.write(h.gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: - fh.write(gen_c(versions, typs)) + fh.write(c.gen_c(versions, typs)) |