summary refs log tree commit diff
path: root/lib9p/protogen/c_marshal.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/protogen/c_marshal.py')
-rw-r--r-- lib9p/protogen/c_marshal.py 386
1 files changed, 386 insertions, 0 deletions
diff --git a/lib9p/protogen/c_marshal.py b/lib9p/protogen/c_marshal.py
new file mode 100644
index 0000000..74b64f5
--- /dev/null
+++ b/lib9p/protogen/c_marshal.py
@@ -0,0 +1,386 @@
+# lib9p/protogen/c_marshal.py - Generate C marshal functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import typing
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_marshal"]  # gen_c_marshal() is the only name this module exports
+
+# get_offset_expr() ############################################################
+
+
+class OffsetExpr:  # symbolic byte-count expression; gen_c() renders it as C statements
+ static: int  # constant number of bytes
+ cond: dict[frozenset[str], "OffsetExpr"]  # extra size keyed by a set of version names; rendered as a version-guarded `if`
+ rep: list[tuple[idlutil.Path, "OffsetExpr"]]  # (path of the count member, size of one element) pairs
+
+ def __init__(self) -> None:  # start as the empty expression: 0 bytes, no repeats, no conditionals
+ self.static = 0
+ self.rep = []
+ self.cond = {}
+
+ def add(self, other: "OffsetExpr") -> None:  # merge `other` into self, in place
+ self.static += other.static
+ self.rep += other.rep
+ for k, v in other.cond.items():
+ if k in self.cond:  # same version set already present: merge recursively
+ self.cond[k].add(v)
+ else:
+ self.cond[k] = v  # stores other's sub-expression object itself (no copy is made)
+
+ def gen_c(  # render this expression as C code and return the text
+ self,
+ dsttyp: str,  # C type used to declare dstvar; "" means dstvar already exists and is added to with `+=`
+ dstvar: str,  # name of the C variable that receives the total
+ root: str,  # prefix handed to Path.c_str() when naming count members
+ indent_depth: int,  # number of leading tabs on each emitted line
+ loop_depth: int,  # selects the loop variable letter: 0 -> i, 1 -> j, ...
+ ) -> str:
+ oneline: list[str] = []  # terms that get joined with " + " into a single statement
+ multiline = ""  # loops and version-guarded blocks, appended after that statement
+ if self.static:
+ oneline.append(str(self.static))
+ for cnt, sub in self.rep:
+ if not sub.cond and not sub.rep:  # element size is a plain constant: emit a count (times size) term, no loop
+ if sub.static == 1:
+ oneline.append(cnt.c_str(root))
+ else:
+ oneline.append(f"({cnt.c_str(root)})*{sub.static}")
+ continue
+ loopvar = chr(ord("i") + loop_depth)
+ multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
+ multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth + 1)  # "" as dsttyp: nested code accumulates into the same dstvar
+ multiline += f"{'\t'*indent_depth}}}\n"
+ for vers, sub in self.cond.items():
+ multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
+ multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
+ multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
+ multiline += f"{'\t'*indent_depth}}}\n"
+ multiline += cutil.ifdef_pop(indent_depth)
+ ret = ""
+ if dsttyp:  # declaring: always emit an initializer, "0" when there are no one-line terms
+ if not oneline:
+ oneline.append("0")
+ ret += f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
+ elif oneline:  # accumulating: emit nothing when there is nothing to add
+ ret += f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
+ ret += multiline
+ return ret
+
+
+type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]  # per-path callback for get_offset_expr(): only paths answered with KEEP_GOING are counted
+
+
+def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:  # build the OffsetExpr for the members of `typ` that idlutil.walk() visits and `recurse` accepts
+ if not isinstance(typ, idl.Struct):  # a non-struct type contributes only its fixed size
+ assert typ.static_size
+ ret = OffsetExpr()
+ ret.static = typ.static_size
+ return ret
+
+ class ExprStackItem(typing.NamedTuple):  # one (sub-)expression that is still being accumulated
+ path: idlutil.Path  # member path that opened this entry
+ expr: OffsetExpr  # size accumulated for it so far
+ pop: typing.Callable[[], None]  # folds `expr` into the entry below and removes this entry
+
+ expr_stack: list[ExprStackItem]
+
+ def pop_root() -> None:  # the bottom (root) entry is never supposed to be popped
+ assert False
+
+ def pop_cond() -> None:  # fold the top entry into its parent's `cond`, keyed by the member's version set
+ nonlocal expr_stack
+ key = frozenset(expr_stack[-1].path.elems[-1].in_versions)
+ if key in expr_stack[-2].expr.cond:
+ expr_stack[-2].expr.cond[key].add(expr_stack[-1].expr)
+ else:
+ expr_stack[-2].expr.cond[key] = expr_stack[-1].expr
+ expr_stack = expr_stack[:-1]
+
+ def pop_rep() -> None:  # fold the top entry into its parent's `rep`, paired with the path of its count member
+ nonlocal expr_stack
+ member_path = expr_stack[-1].path
+ member = member_path.elems[-1]
+ assert member.cnt
+ cnt_path = member_path.parent().add(member.cnt)  # the count member is addressed as a sibling of the repeated member
+ expr_stack[-2].expr.rep.append((cnt_path, expr_stack[-1].expr))
+ expr_stack = expr_stack[:-1]
+
+ def handle(  # callback passed to idlutil.walk(); returns the walk command plus an optional unwind callback
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
+ nonlocal recurse
+
+ ret = recurse(path)
+ if ret != idlutil.WalkCmd.KEEP_GOING:  # caller's callback rejected this path: hand its answer back, count nothing
+ return ret, None
+
+ nonlocal expr_stack
+ expr_stack_len = len(expr_stack)  # stack depth on entry; pop() unwinds back to it
+
+ def pop() -> None:  # run the pop callback of every entry pushed for this path
+ nonlocal expr_stack
+ nonlocal expr_stack_len
+ while len(expr_stack) > expr_stack_len:
+ expr_stack[-1].pop()
+
+ if path.elems:  # the root path has no member of its own to account for
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:  # presumably sets, so `<` is a proper-subset test: member exists in fewer versions than its parent
+ expr_stack.append(
+ ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_cond)
+ )
+ if child.cnt:  # repeated member: its size is tracked per element
+ expr_stack.append(
+ ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_rep)
+ )
+ if not isinstance(child.typ, idl.Struct):  # non-struct member: add its fixed size to the innermost open expression
+ assert child.typ.static_size
+ expr_stack[-1].expr.static += child.typ.static_size
+ return ret, pop
+
+ expr_stack = [
+ ExprStackItem(path=idlutil.Path(typ), expr=OffsetExpr(), pop=pop_root)
+ ]
+ idlutil.walk(typ, handle)
+ return expr_stack[0].expr  # the root entry's expression is the result
+
+
+def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:  # OffsetExprRecursion that accepts every path; `path` is ignored
+ return idlutil.WalkCmd.KEEP_GOING
+
+
+def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:  # build an OffsetExprRecursion that answers ABORT at the top-level member called `name`
+ def ret(path: idlutil.Path) -> idlutil.WalkCmd:
+ if len(path.elems) == 1 and path.elems[0].membname == name:  # only a path exactly one member deep can match
+ return idlutil.WalkCmd.ABORT
+ return idlutil.WalkCmd.KEEP_GOING
+
+ return ret
+
+
+# Generate .c ##################################################################
+
+
+def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:  # return C source text: the MARSHAL_* macros, then one static marshal_<typname>() per selected type; `versions` is not referenced in this body
+ ret = """
+/* marshal_* ******************************************************************/
+
+"""
+ ret += cutil.macro(  # generated macro points an iovec at the caller's buffer instead of copying the bytes
+ "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
+ "\tif (ret->net_iov[ret->net_iov_cnt-1].iov_len)\n"
+ "\t\tret->net_iov_cnt++;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_base = data;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len = len;\n"
+ "\tret->net_iov_cnt++;\n"
+ )
+ ret += cutil.macro(  # generated macro memcpy()s the bytes into net_copied and grows the current iovec
+ "#define MARSHAL_BYTES(ctx, data, len)\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tmemcpy(&ret->net_copied[ret->net_copied_size], data, len);\n"
+ "\tret->net_copied_size += len;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += len;\n"
+ )
+ ret += cutil.macro(  # the four MARSHAL_U<bits>LE macros store one fixed-width integer into net_copied and grow the current iovec by its size
+ "#define MARSHAL_U8LE(ctx, val)\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tret->net_copied[ret->net_copied_size] = val;\n"
+ "\tret->net_copied_size += 1;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 1;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U16LE(ctx, val)\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint16le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 2;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 2;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U32LE(ctx, val)\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint32le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 4;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 4;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U64LE(ctx, val)\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint64le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 8;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 8;\n"
+ )
+
+ class IndentLevel(typing.NamedTuple):
+ ifdef: bool # whether this is both `{` and `#if`, or just `{`
+
+ indent_stack: list[IndentLevel]  # currently open levels; re-initialised for each generated function
+
+ def ifdef_lvl() -> int:  # number of open levels that also opened an `#if`
+ return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
+
+ def indent_lvl() -> int:  # total number of open levels; used as the tab count for emitted lines
+ return len(indent_stack)
+
+ max_size: int  # assigned per type in the loop at the bottom; read inside handle()
+
+ def handle(  # callback passed to idlutil.walk(): appends the C code for one path to `ret`
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal max_size
+ indent_stack_len = len(indent_stack)  # depth on entry; pop() unwinds towards it
+
+ def pop() -> None:  # emit the closing `}` (and matching ifdef pop) for levels opened for this path
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal indent_stack_len
+ while len(indent_stack) > indent_stack_len:
+ if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:  # a single remaining ifdef level is left open here -- NOTE(review): reason not visible in this file; confirm intent
+ break
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)  # number of counted (repeated) members along the path
+ struct = path.elems[-1].typ if path.elems else path.root
+ if isinstance(struct, idl.Struct):
+ offsets: list[str] = []  # "end" / "&name" symbols used by member value expressions of this struct, in first-seen order
+ for member in struct.members:
+ if not member.val:
+ continue
+ for tok in member.val.tokens:
+ if not isinstance(tok, idl.ExprSym):
+ continue
+ if tok.symname == "end" or tok.symname.startswith("&"):
+ if tok.symname not in offsets:
+ offsets.append(tok.symname)
+ for name in offsets:  # declare one C offset variable per collected symbol
+ name_prefix = f"offsetof{''.join('_'+m.membname for m in path.elems)}_"
+ if name == "end":
+ if not path.elems:  # at the root, "end" reuses needed_size from the size check (cast when that variable is 64-bit)
+ if max_size > cutil.UINT32_MAX:
+ ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
+ else:
+ ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = needed_size;\n"
+ continue
+ recurse: OffsetExprRecursion = go_to_end
+ else:
+ assert name.startswith("&")
+ name = name[1:]  # drop the leading "&" to get the member name
+ recurse = go_to_tok(name)
+ expr = get_offset_expr(struct, recurse)
+ expr_prefix = path.c_str("val->", loopdepth)
+ if not expr_prefix.endswith(">"):  # prefix does not end in "->": join member names with "."
+ expr_prefix += "."
+ ret += expr.gen_c(
+ "uint32_t",
+ name_prefix + name,
+ expr_prefix,
+ indent_lvl(),
+ loopdepth,
+ )
+ if not path.elems:  # root path: there is no member to marshal
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:  # presumably sets, so `<` is a proper-subset test: guard the member with a version check
+ if line := cutil.ifdef_push(
+ ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
+ ):
+ ret += line
+ ret += (
+ f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+ )
+ indent_stack.append(IndentLevel(ifdef=True))
+ if child.cnt:
+ cnt_path = path.parent().add(child.cnt)  # the count member is addressed as a sibling of this member
+ if child.typ.static_size == 1: # SPECIAL (zerocopy) -- 1-byte elements are emitted as one bulk macro call instead of a loop
+ if path.root.typname == "stat": # SPECIAL (stat) -- "stat" uses the copying macro; every other type uses the zero-copy one
+ ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"  # [:-3] drops the last 3 chars of the C path -- presumably a trailing "[i]" index; TODO confirm
+ else:
+ ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+ loopvar = chr(ord("i") + loopdepth - 1)  # loopdepth already counts this member, hence the -1: outermost loop uses `i`
+ ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
+ indent_stack.append(IndentLevel(ifdef=False))
+ if not isinstance(child.typ, idl.Struct):  # non-struct member: emit one MARSHAL_U<bits>LE call
+ if child.val:  # member has a value expression: marshal that instead of the field
+
+ def lookup_sym(sym: str) -> str:  # map an expression symbol to the offsetof_* C variable name declared for the enclosing struct
+ nonlocal path
+ if sym.startswith("&"):
+ sym = sym[1:]
+ return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
+
+ val = c9util.idl_expr(child.val, lookup_sym)
+ else:
+ val = path.c_str("val->")
+ if isinstance(child.typ, idl.Bitfield):  # bitfields are ANDed with the <typname>_masks entry for ctx->version
+ val += f" & {child.typ.typname}_masks[ctx->version]"
+ ret += f"{'\t'*indent_lvl()}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ for typ in typs:  # one marshal function per message type, plus "stat"
+ if not (
+ isinstance(typ, idl.Message) or typ.typname == "stat"
+ ): # SPECIAL (include stat)
+ continue
+ assert isinstance(typ, idl.Struct)
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"static bool marshal_{typ.typname}(struct lib9p_ctx *ctx, {c9util.typename(typ)} *val, struct _marshal_ret *ret) {{\n"
+
+ # Pass 1 - check size
+ max_size = max(typ.max_size(v) for v in typ.in_versions)  # largest possible size across the versions this type exists in
+
+ if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e)
+ ret += get_offset_expr(typ, go_to_end).gen_c(  # size may exceed 32 bits: compute needed_size as uint64_t
+ "uint64_t", "needed_size", "val->", 1, 0
+ )
+ ret += "\tif (needed_size > (uint64_t)(ctx->max_msg_size)) {\n"
+ else:
+ ret += get_offset_expr(typ, go_to_end).gen_c(
+ "uint32_t", "needed_size", "val->", 1, 0
+ )
+ ret += "\tif (needed_size > ctx->max_msg_size) {\n"
+ if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
+ ret += '\t\tlib9p_errorf(ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
+ ret += f'\t\t\t"{typ.typname}",\n'
+ ret += f'\t\t\tctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'  # even msgid -> "client", odd msgid -> "server"
+ ret += "\t\t\tctx->max_msg_size);\n"
+ ret += "\t\treturn true;\n"  # generated function returns true on the too-large path (and false at its end)
+ ret += "\t}\n"
+
+ # Pass 2 - write data
+ indent_stack = [IndentLevel(ifdef=True)]  # base level: the function body, opened together with the ifdef_push above
+ idlutil.walk(typ, handle)
+ while len(indent_stack) > 1:  # close whatever handle()/pop() left open above the base level
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ # Return
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ ret += cutil.ifdef_pop(0)
+ return ret