diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 750 |
1 files changed, 622 insertions, 128 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index adb15ce..f2b4f13 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -5,6 +5,7 @@ # Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> # SPDX-License-Identifier: AGPL-3.0-or-later +import enum import graphlib import os.path import sys @@ -145,6 +146,9 @@ def ifdef_pop(n: int) -> str: return ret +# topo_sorted() ################################################################ + + def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]: ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter() for typ in typs: @@ -163,6 +167,189 @@ def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]: return ts.static_order() +# walk() ####################################################################### + + +class Path: + root: idl.Type + elems: list[idl.StructMember] + + def __init__( + self, root: idl.Type, elems: list[idl.StructMember] | None = None + ) -> None: + self.root = root + self.elems = elems if elems is not None else [] + + def add(self, elem: idl.StructMember) -> "Path": + return Path(self.root, self.elems + [elem]) + + def parent(self) -> "Path": + return Path(self.root, self.elems[:-1]) + + def c_str(self, base: str, loopdepth: int = 0) -> str: + ret = base + for i, elem in enumerate(self.elems): + if i > 0: + ret += "." + ret += elem.name + if elem.cnt: + ret += f"[{chr(ord('i')+loopdepth)}]" + loopdepth += 1 + return ret + + def __str__(self) -> str: + return self.c_str(self.root.name + "->") + + +class WalkCmd(enum.Enum): + KEEP_GOING = 1 + DONT_RECURSE = 2 + ABORT = 3 + + +type WalkHandler = typing.Callable[ + [Path], tuple[WalkCmd, typing.Callable[[], None] | None] +] + + +def _walk(path: Path, handle: WalkHandler) -> WalkCmd: + typ = path.elems[-1].typ if path.elems else path.root + + ret, atexit = handle(path) + + if isinstance(typ, idl.Struct): + match ret: + case WalkCmd.KEEP_GOING: + for member in typ.members: + if _walk(path.add(member), handle) == WalkCmd.ABORT: + ret = WalkCmd.ABORT + break + case WalkCmd.DONT_RECURSE: + ret = WalkCmd.KEEP_GOING + case WalkCmd.ABORT: + ret = WalkCmd.ABORT + case _: + assert False, f"invalid cmd: {ret}" + + if atexit: + atexit() + return ret + + +def walk(typ: idl.Type, handle: WalkHandler) -> None: + _walk(Path(typ), handle) + + +# get_buffer_size() ############################################################ + + +class BufferSize: + min_size: int # really just here to sanity-check against typ.min_size(version) + exp_size: int # "expected" or max-reasonable size + max_size: int # really just here to sanity-check against typ.max_size(version) + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + _starts_with_copy: bool + _ends_with_copy: bool + + def __init__(self) -> None: + self.min_size = 0 + self.exp_size = 0 + self.max_size = 0 + self.max_copy = 0 + self.max_copy_extra = "" + self.max_iov = 0 + self.max_iov_extra = "" + self._starts_with_copy = False + self._ends_with_copy = False + + +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) + + ret = BufferSize() + + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret.min_size = typ.static_size + ret.exp_size = typ.static_size + ret.max_size = typ.static_size + ret.max_copy = typ.static_size + ret.max_iov = 1 + ret._starts_with_copy = True + ret._ends_with_copy = True + return ret + + def handle(path: Path) -> tuple[WalkCmd, None]: + nonlocal ret + if path.elems: + child = path.elems[-1] + if version not in child.in_versions: + return WalkCmd.DONT_RECURSE, None + if child.cnt: + if child.typ.static_size == 1: # SPECIAL (zerocopy) + ret.max_iov += 1 + # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data + ret.exp_size += 27 if child.name == "utf8" else 8192 + ret.max_size += child.max_cnt + ret._ends_with_copy = False + return WalkCmd.DONT_RECURSE, None + sub = get_buffer_size(child.typ, version) + ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM + ret.max_size += sub.max_size * child.max_cnt + if child.name == "wname" and path.root.name in ( + "Tsread", + "Tswrite", + ): # SPECIAL (9P2000.e) + assert ret._ends_with_copy + assert sub._starts_with_copy + assert not sub._ends_with_copy + ret.max_copy_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" + ) + ret.max_iov_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" + ) + ret.max_iov -= 1 + else: + ret.max_copy += sub.max_copy * child.max_cnt + if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy + ret.max_iov += 1 + else: # contains zero-copy segments + ret.max_iov += sub.max_iov * child.max_cnt + if ret._ends_with_copy and sub._starts_with_copy: + # we can merge this one + ret.max_iov -= 1 + if ( + sub._ends_with_copy + and sub._starts_with_copy + and sub.max_iov > 1 + ): + # we can merge these + ret.max_iov -= child.max_cnt - 1 + ret._ends_with_copy = sub._ends_with_copy + return WalkCmd.DONT_RECURSE, None + elif not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + if not ret._ends_with_copy: + if ret.max_size == 0: + ret._starts_with_copy = True + ret.max_iov += 1 + ret._ends_with_copy = True + ret.min_size += child.typ.static_size + ret.exp_size += child.typ.static_size + ret.max_size += child.typ.static_size + ret.max_copy += child.typ.static_size + return WalkCmd.KEEP_GOING, None + + walk(typ, handle) + assert ret.min_size == typ.min_size(version) + assert ret.max_size == typ.max_size(version) + return ret + + # Generate .h ################################################################## @@ -177,6 +364,8 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str: #endif #include <stdint.h> /* for uint{{n}}_t types */ + +#include <libhw/generic/net.h> /* for struct iovec */ """ id2typ: dict[int, idl.Message] = {} @@ -192,6 +381,14 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str: ret += "\n" ret += f"#ifndef {c_ver_ifdef({ver})}\n" ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" + if ver == "9P2000.e": # SPECIAL (9P2000.e) + ret += "#else\n" + ret += f"\t#if {c_ver_ifdef({ver})}\n" + ret += "\t\t#ifndef(CONFIG_9P_MAX_9P2000_e_WELEM)\n" + ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += "\t\t#endif\n" + ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" + ret += "\t#endif\n" ret += "#endif\n" ret += f""" @@ -251,16 +448,20 @@ enum {idprefix}version {{ ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) def sum_size(typ: idl.Type, version: str) -> str: - min_size = typ.min_size(version) - max_size = typ.max_size(version) - assert min_size <= max_size and max_size < u64max + sz = get_buffer_size(typ, version) + assert ( + sz.min_size <= sz.exp_size + and sz.exp_size <= sz.max_size + and sz.max_size < u64max + ) ret = "" - if min_size == max_size: - ret += f"size = {min_size:,}" + if sz.min_size == sz.max_size: + ret += f"size = {sz.min_size:,}" else: - ret += f"min_size = {min_size:,} ; max_size = {max_size:,}" - if max_size > u32max: + ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" + if sz.max_size > u32max: ret += " (warning: >UINT32_MAX)" + ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" return ret ret += per_version_comment(typ, sum_size) @@ -331,6 +532,87 @@ enum {idprefix}version {{ ret += "};\n" ret += ifdef_pop(0) + ret += """ +/* containers *****************************************************************/ +""" + ret += "\n" + ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n" + + tmsg_max_iov: dict[str, int] = {} + tmsg_max_copy: dict[str, int] = {} + rmsg_max_iov: dict[str, int] = {} + rmsg_max_copy: dict[str, int] = {} + for typ in typs: + if not isinstance(typ, idl.Message): + continue + if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) + continue + max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov + max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy + for version in typ.in_versions: + if version not in max_iov: + max_iov[version] = 0 + max_copy[version] = 0 + sz = get_buffer_size(typ, version) + if sz.max_iov > max_iov[version]: + max_iov[version] = sz.max_iov + if sz.max_copy > max_copy[version]: + max_copy[version] = sz.max_copy + + for name, table in [ + ("tmsg_max_iov", tmsg_max_iov), + ("tmsg_max_copy", tmsg_max_copy), + ("rmsg_max_iov", rmsg_max_iov), + ("rmsg_max_copy", rmsg_max_copy), + ]: + inv: dict[int, set[str]] = {} + for version, maxval in table.items(): + if maxval not in inv: + inv[maxval] = set() + inv[maxval].add(version) + + ret += "\n" + directive = "if" + seen_e = False # SPECIAL (9P2000.e) + for maxval in sorted(inv, reverse=True): + ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" + indent = 1 + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + typ = next(typ for typ in typs if typ.name == "Tswrite") + sz = get_buffer_size(typ, "9P2000.e") + match name: + case "tmsg_max_iov": + maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" + case "tmsg_max_copy": + maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" + case _: + assert False + ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" + ret += f"\t#else\n" + indent += 1 + ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + ret += "\t#endif\n" + if "9P2000.e" in inv[maxval]: + seen_e = True + directive = "elif" + ret += "#endif\n" + + ret += "\n" + ret += f"struct {idprefix}Tmsg_send_buf {{\n" + ret += f"\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" + ret += "};\n" + + ret += "\n" + ret += f"struct {idprefix}Rmsg_send_buf {{\n" + ret += f"\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" + ret += "};\n" + return ret @@ -549,7 +831,10 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: - ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if member.typ.static_size == 1: # SPECIAL (zerocopy) + ret += f"_validate_size_net(ctx, {member.cnt.name})" + else: + ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" if typ.name == "s": # SPECIAL (string) ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' else: @@ -653,13 +938,15 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o if member.in_versions != typ.in_versions: ret += "{\n" ret += prefix - ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" - if member.typ.static_size == 1: # SPECIAL (string) - # Special-case is that we cast from `char` to `uint8_t`. - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" + if member.typ.static_size == 1: # SPECIAL (string, zerocopy) + ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += ( + f"{prefix}ctx->net_offset += out->{member.cnt.name};\n" + ) else: + ret += f"out->{member.name} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" @@ -675,139 +962,346 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += """ /* marshal_* ******************************************************************/ -LM_ALWAYS_INLINE static bool _marshal_too_large(struct _marshal_ctx *ctx) { -\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", -\t\t(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", -\t\tctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), -\t\tctx->ctx->max_msg_size); -\treturn true; -} +""" + ret += c_macro( + "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" + "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" + "\t\tctx->net_iov_cnt++;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" + "\tctx->net_iov_cnt++;\n" + ) + ret += c_macro( + "#define MARSHAL_BYTES(ctx, data, len)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" + "\tctx->net_copied_size += len;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" + ) + ret += c_macro( + "#define MARSHAL_U8LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tctx->net_copied[ctx->net_copied_size] = val;\n" + "\tctx->net_copied_size += 1;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" + ) + ret += c_macro( + "#define MARSHAL_U16LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 2;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" + ) + ret += c_macro( + "#define MARSHAL_U32LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 4;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" + ) + ret += c_macro( + "#define MARSHAL_U64LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 8;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" + ) -LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { -\tif (ctx->net_offset + 1 > ctx->ctx->max_msg_size) -\t\treturn _marshal_too_large(ctx); -\tctx->net_bytes[ctx->net_offset] = *val; -\tctx->net_offset += 1; -\treturn false; -} + class OffsetExpr: + static: int + cond: dict[frozenset[str], "OffsetExpr"] + rep: list[tuple[Path, "OffsetExpr"]] + + def __init__(self) -> None: + self.static = 0 + self.rep = [] + self.cond = {} + + def add(self, other: "OffsetExpr") -> None: + self.static += other.static + self.rep += other.rep + for k, v in other.cond.items(): + if k in self.cond: + self.cond[k].add(v) + else: + self.cond[k] = v + + def gen_c( + self, + dsttyp: str, + dstvar: str, + root: str, + indent_depth: int, + loop_depth: int, + ) -> str: + oneline: list[str] = [] + multiline = "" + if self.static: + oneline.append(str(self.static)) + for cnt, sub in self.rep: + if not sub.cond and not sub.rep: + if sub.static == 1: + oneline.append(cnt.c_str(root)) + else: + oneline.append(f"({cnt.c_str(root)})*{sub.static}") + continue + loopvar = chr(ord("i") + loop_depth) + multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += sub.gen_c( + "", dstvar, root, indent_depth + 1, loop_depth + 1 + ) + multiline += f"{'\t'*indent_depth}}}\n" + for vers, sub in self.cond.items(): + multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" + multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) + multiline += f"{'\t'*indent_depth}}}\n" + multiline += ifdef_pop(indent_depth) + if dsttyp: + if not oneline: + oneline.append("0") + ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" + elif oneline: + ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" + ret += multiline + return ret -LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { -\tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size) -\t\treturn _marshal_too_large(ctx); -\tuint16le_encode(&ctx->net_bytes[ctx->net_offset], *val); -\tctx->net_offset += 2; -\treturn false; -} + type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] -LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { -\tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size) -\t\treturn true; -\tuint32le_encode(&ctx->net_bytes[ctx->net_offset], *val); -\tctx->net_offset += 4; -\treturn false; -} + def get_offset_expr(typ: idl.Type, recurse: OffsetExprRecursion) -> OffsetExpr: + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret = OffsetExpr() + ret.static = typ.static_size + return ret -LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { -\tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size) -\t\treturn true; -\tuint64le_encode(&ctx->net_bytes[ctx->net_offset], *val); -\tctx->net_offset += 8; -\treturn false; -} -""" - for typ in topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used + stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]] + + def pop_root() -> None: + assert False + + def pop_cond() -> None: + nonlocal stack + key = frozenset(stack[-1][0].elems[-1].in_versions) + if key in stack[-2][1].cond: + stack[-2][1].cond[key].add(stack[-1][1]) + else: + stack[-2][1].cond[key] = stack[-1][1] + stack = stack[:-1] + + def pop_rep() -> None: + nonlocal stack + member_path = stack[-1][0] + member = member_path.elems[-1] + assert member.cnt + cnt_path = member_path.parent().add(member.cnt) + stack[-2][1].rep.append((cnt_path, stack[-1][1])) + stack = stack[:-1] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]: + nonlocal recurse + + ret = recurse(path) + if ret != WalkCmd.KEEP_GOING: + return ret, None + + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + stack[-1][2]() + + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + stack.append((path, OffsetExpr(), pop_cond)) + if child.cnt: + stack.append((path, OffsetExpr(), pop_rep)) + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + stack[-1][1].static += child.typ.static_size + return ret, pop + + stack = [(Path(typ), OffsetExpr(), pop_root)] + walk(typ, handle) + return stack[0][1] + + def go_to_end(path: Path) -> WalkCmd: + return WalkCmd.KEEP_GOING + + def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: + def ret(path: Path) -> WalkCmd: + if len(path.elems) == 1 and path.elems[0].name == name: + return WalkCmd.ABORT + return WalkCmd.KEEP_GOING + + return ret + + for typ in typs: + if not ( + isinstance(typ, idl.Message) or typ.name == "stat" + ): # SPECIAL (include stat) + continue + assert isinstance(typ, idl.Struct) ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n" - match typ: - case idl.Number(): - ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n" - case idl.Bitfield(): - ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n" - ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n" - case idl.Struct(): - if len(typ.members) == 0: - ret += "\treturn false;\n" - ret += "}\n" - continue + ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" - # Pass 1 - declare offset variables - mark_offset = set() - for member in typ.members: - if member.val: - if member.name not in mark_offset: - ret += f"\tuint32_t _{member.name}_offset;\n" - mark_offset.add(member.name) - for tok in member.val.tokens: - if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): - if tok.name[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" - mark_offset.add(tok.name[1:]) + # Pass 1 - check size + max_size = max(typ.max_size(v) for v in typ.in_versions) - # Pass 2 - main pass - ret += "\treturn false\n" - for member in typ.members: - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += "\t || " - if member.in_versions != typ.in_versions: - ret += "( " + c_ver_cond(member.in_versions) + " && " - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - if member.cnt: - ret += "({ bool err = false;\n" - ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n" - ret += "\t \terr = " - if member.typ.static_size == 1: # SPECIAL (string) - # Special-case is that we cast from `char` to `uint8_t`. - ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" - else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n" - ret += f"\t err; }})" - elif member.val: - # Just increment net_offset, don't actually marshal anything (yet). - assert member.static_size - ret += ( - f"({{ ctx->net_offset += {member.static_size}; false; }})" - ) + if max_size > u32max: # SPECIAL (9P2000.e) + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint64_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" + else: + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint32_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" + if isinstance(typ, idl.Message): # SPECIAL (disable for stat) + ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%PRIu32)",\n' + ret += f'\t\t\t"{typ.name}",\n' + ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' + ret += "\t\t\tctx->ctx->max_msg_size);\n" + ret += "\t\treturn true;\n" + ret += "\t}\n" + + # Pass 2 - write data + ifdef_depth = 1 + stack: list[tuple[Path, bool]] = [(Path(typ), False)] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + ret += f"{'\t'*(len(stack)-1)}}}\n" + if stack[-1][1]: + ifdef_depth -= 1 + ret += ifdef_pop(ifdef_depth) + stack = stack[:-1] + + loopdepth = sum(1 for elem in path.elems if elem.cnt) + struct = path.elems[-1].typ if path.elems else path.root + if isinstance(struct, idl.Struct): + offsets: list[str] = [] + for member in struct.members: + if not member.val: + continue + for tok in member.val.tokens: + if not isinstance(tok, idl.ExprSym): + continue + if tok.name == "end" or tok.name.startswith("&"): + if tok.name not in offsets: + offsets.append(tok.name) + for name in offsets: + name_prefix = "offsetof_" + "".join( + m.name + "_" for m in path.elems + ) + if name == "end": + if not path.elems: + nonlocal max_size + if max_size > u32max: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" + else: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" + continue + recurse: OffsetExprRecursion = go_to_end else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" - if member.name in mark_offset: - ret += "; })" - if member.in_versions != typ.in_versions: - ret += " )" - ret += "\n" - - # Pass 3 - marshal ,val= members - for member in typ.members: - if member.val: - assert member.static_size - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + assert name.startswith("&") + name = name[1:] + recurse = go_to_tok(name) + expr = get_offset_expr(struct, recurse) + expr_prefix = path.c_str("val->", loopdepth) + if not expr_prefix.endswith(">"): + expr_prefix += "." + ret += expr.gen_c( + "uint32_t", + name_prefix + name, + expr_prefix, + len(stack), + loopdepth, + ) + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions)) + ifdef_depth += 1 + ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" + stack.append((path, True)) + if child.cnt: + cnt_path = path.parent().add(child.cnt) + if child.typ.static_size == 1: # SPECIAL (zerocopy) + if path.root.name == "stat": # SPECIAL (stat) + ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + else: + ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + return WalkCmd.KEEP_GOING, pop + loopvar = chr(ord("i") + loopdepth - 1) + ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + stack.append((path, False)) + if not isinstance(child.typ, idl.Struct): + if child.val: def lookup_sym(sym: str) -> str: - match sym: - case "end": - return "ctx->net_offset" - case _: - assert sym.startswith("&") - return f"_{sym[1:]}_offset" - - if member.static_size == 1: - ret += f"\t || ({{ ctx->net_bytes[_{member.name}_offset] = {c_expr(member.val, lookup_sym)}; false; }})\n" - else: - ret += f"\t || ({{ uint{member.static_size*8}le_encode(&ctx->net_bytes[_{member.name}_offset], {c_expr(member.val, lookup_sym)}); false; }})\n" + nonlocal path + if sym.startswith("&"): + sym = sym[1:] + return ( + "offsetof_" + + "".join(m.name + "_" for m in path.elems[:-1]) + + sym + ) + + val = c_expr(child.val, lookup_sym) + else: + val = path.c_str("val->") + if isinstance(child.typ, idl.Bitfield): + val += f" & {child.typ.name}_masks[ctx->ctx->version]" + ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" + return WalkCmd.KEEP_GOING, pop - ret += ifdef_pop(1) - ret += "\t ;\n" + walk(typ, handle) + + ret += "\treturn false;\n" ret += "}\n" ret += ifdef_pop(0) # function tables ########################################################## ret += """ /* function tables ************************************************************/ - """ + + ret += "\n" + ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" + rerror = next(typ for typ in typs if typ.name == "Rerror") + ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + for ver in sorted(versions): + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" + ret += ifdef_pop(0) + ret += "};\n" + + ret += "\n" ret += c_macro( f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" |