# lib9p/protogen/c_marshal.py - Generate C marshal functions # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import typing import idl from . import c9util, cutil, idlutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # pylint: disable=unused-variable __all__ = ["gen_c_marshal"] # get_offset_expr() ############################################################ class OffsetExpr: static: int cond: dict[frozenset[str], "OffsetExpr"] rep: list[tuple[idlutil.Path, "OffsetExpr"]] def __init__(self) -> None: self.static = 0 self.rep = [] self.cond = {} def add(self, other: "OffsetExpr") -> None: self.static += other.static self.rep += other.rep for k, v in other.cond.items(): if k in self.cond: self.cond[k].add(v) else: self.cond[k] = v def gen_c( self, dsttyp: str, dstvar: str, root: str, indent_depth: int, loop_depth: int, ) -> str: oneline: list[str] = [] multiline = "" if self.static: oneline.append(str(self.static)) for cnt, sub in self.rep: if not sub.cond and not sub.rep: if sub.static == 1: oneline.append(cnt.c_str(root)) else: oneline.append(f"({cnt.c_str(root)})*{sub.static}") continue loopvar = chr(ord("i") + loop_depth) multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth + 1) multiline += f"{'\t'*indent_depth}}}\n" for vers, sub in self.cond.items(): multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) multiline += f"{'\t'*indent_depth}}}\n" multiline += cutil.ifdef_pop(indent_depth) ret = "" if dsttyp: if not oneline: oneline.append("0") ret += f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" elif oneline: ret += f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" ret += multiline return ret type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: if not isinstance(typ, idl.Struct): assert typ.static_size ret = OffsetExpr() ret.static = typ.static_size return ret class ExprStackItem(typing.NamedTuple): path: idlutil.Path expr: OffsetExpr pop: typing.Callable[[], None] expr_stack: list[ExprStackItem] def pop_root() -> None: assert False def pop_cond() -> None: nonlocal expr_stack key = frozenset(expr_stack[-1].path.elems[-1].in_versions) if key in expr_stack[-2].expr.cond: expr_stack[-2].expr.cond[key].add(expr_stack[-1].expr) else: expr_stack[-2].expr.cond[key] = expr_stack[-1].expr expr_stack = expr_stack[:-1] def pop_rep() -> None: nonlocal expr_stack member_path = expr_stack[-1].path member = member_path.elems[-1] assert member.cnt cnt_path = member_path.parent().add(member.cnt) expr_stack[-2].expr.rep.append((cnt_path, expr_stack[-1].expr)) expr_stack = expr_stack[:-1] def handle( path: idlutil.Path, ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: nonlocal recurse ret = recurse(path) if ret != idlutil.WalkCmd.KEEP_GOING: return ret, None nonlocal expr_stack expr_stack_len = len(expr_stack) def pop() -> None: nonlocal expr_stack nonlocal expr_stack_len while len(expr_stack) > expr_stack_len: expr_stack[-1].pop() if path.elems: child = path.elems[-1] parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: expr_stack.append( ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_cond) ) if child.cnt: expr_stack.append( ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_rep) ) if not isinstance(child.typ, idl.Struct): assert child.typ.static_size expr_stack[-1].expr.static += child.typ.static_size return ret, pop expr_stack = [ ExprStackItem(path=idlutil.Path(typ), expr=OffsetExpr(), pop=pop_root) ] idlutil.walk(typ, handle) return expr_stack[0].expr def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: return idlutil.WalkCmd.KEEP_GOING def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: def ret(path: idlutil.Path) -> idlutil.WalkCmd: if len(path.elems) == 1 and path.elems[0].membname == name: return idlutil.WalkCmd.ABORT return idlutil.WalkCmd.KEEP_GOING return ret # Generate .c ################################################################## def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str: ret = """ /* marshal_* ******************************************************************/ """ ret += cutil.macro( "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" "\t\tctx->net_iov_cnt++;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" "\tctx->net_iov_cnt++;\n" ) ret += cutil.macro( "#define MARSHAL_BYTES(ctx, data, len)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" "\tctx->net_copied_size += len;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" ) ret += cutil.macro( "#define MARSHAL_U8LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tctx->net_copied[ctx->net_copied_size] = val;\n" "\tctx->net_copied_size += 1;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" ) ret += cutil.macro( "#define MARSHAL_U16LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 2;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" ) ret += cutil.macro( "#define MARSHAL_U32LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 4;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" ) ret += cutil.macro( "#define MARSHAL_U64LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 8;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" ) class IndentLevel(typing.NamedTuple): ifdef: bool # whether this is both `{` and `#if`, or just `{` indent_stack: list[IndentLevel] def ifdef_lvl() -> int: return sum(1 if lvl.ifdef else 0 for lvl in indent_stack) def indent_lvl() -> int: return len(indent_stack) max_size: int def handle( path: idlutil.Path, ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: nonlocal ret nonlocal indent_stack nonlocal max_size indent_stack_len = len(indent_stack) def pop() -> None: nonlocal ret nonlocal indent_stack nonlocal indent_stack_len while len(indent_stack) > indent_stack_len: if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef: break ret += f"{'\t'*(indent_lvl()-1)}}}\n" if indent_stack.pop().ifdef: ret += cutil.ifdef_pop(ifdef_lvl()) loopdepth = sum(1 for elem in path.elems if elem.cnt) struct = path.elems[-1].typ if path.elems else path.root if isinstance(struct, idl.Struct): offsets: list[str] = [] for member in struct.members: if not member.val: continue for tok in member.val.tokens: if not isinstance(tok, idl.ExprSym): continue if tok.symname == "end" or tok.symname.startswith("&"): if tok.symname not in offsets: offsets.append(tok.symname) for name in offsets: name_prefix = f"offsetof{''.join('_'+m.membname for m in path.elems)}_" if name == "end": if not path.elems: if max_size > cutil.UINT32_MAX: ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" else: ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = needed_size;\n" continue recurse: OffsetExprRecursion = go_to_end else: assert name.startswith("&") name = name[1:] recurse = go_to_tok(name) expr = get_offset_expr(struct, recurse) expr_prefix = path.c_str("val->", loopdepth) if not expr_prefix.endswith(">"): expr_prefix += "." ret += expr.gen_c( "uint32_t", name_prefix + name, expr_prefix, indent_lvl(), loopdepth, ) if not path.elems: return idlutil.WalkCmd.KEEP_GOING, pop child = path.elems[-1] parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: if line := cutil.ifdef_push( ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions) ): ret += line ret += ( f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n" ) indent_stack.append(IndentLevel(ifdef=True)) if child.cnt: cnt_path = path.parent().add(child.cnt) if child.typ.static_size == 1: # SPECIAL (zerocopy) if path.root.typname == "stat": # SPECIAL (stat) ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" return idlutil.WalkCmd.KEEP_GOING, pop loopvar = chr(ord("i") + loopdepth - 1) ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" indent_stack.append(IndentLevel(ifdef=False)) if not isinstance(child.typ, idl.Struct): if child.val: def lookup_sym(sym: str) -> str: nonlocal path if sym.startswith("&"): sym = sym[1:] return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}" val = c9util.idl_expr(child.val, lookup_sym) else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): val += f" & {child.typ.typname}_masks[ctx->ctx->version]" ret += f"{'\t'*indent_lvl()}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" return idlutil.WalkCmd.KEEP_GOING, pop for typ in typs: if not ( isinstance(typ, idl.Message) or typ.typname == "stat" ): # SPECIAL (include stat) continue assert isinstance(typ, idl.Struct) ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) ret += get_offset_expr(typ, go_to_end).gen_c( "uint64_t", "needed_size", "val->", 1, 0 ) ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" else: ret += get_offset_expr(typ, go_to_end).gen_c( "uint32_t", "needed_size", "val->", 1, 0 ) ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" if isinstance(typ, idl.Message): # SPECIAL (disable for stat) ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' ret += f'\t\t\t"{typ.typname}",\n' ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' ret += "\t\t\tctx->ctx->max_msg_size);\n" ret += "\t\treturn true;\n" ret += "\t}\n" # Pass 2 - write data indent_stack = [IndentLevel(ifdef=True)] idlutil.walk(typ, handle) while len(indent_stack) > 1: ret += f"{'\t'*(indent_lvl()-1)}}}\n" if indent_stack.pop().ifdef: ret += cutil.ifdef_pop(ifdef_lvl()) # Return ret += "\treturn false;\n" ret += "}\n" ret += cutil.ifdef_pop(0) return ret