summary refs log tree commit diff
path: root/lib9p/protogen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/protogen')
-rw-r--r-- lib9p/protogen/c.py 18
-rw-r--r-- lib9p/protogen/c9util.py 8
-rw-r--r-- lib9p/protogen/c_marshal.py 583
-rw-r--r-- lib9p/protogen/c_unmarshal.py 169
-rw-r--r-- lib9p/protogen/c_validate.py 394
-rw-r--r-- lib9p/protogen/cutil.py 2
6 files changed, 676 insertions, 498 deletions
diff --git a/lib9p/protogen/c.py b/lib9p/protogen/c.py
index a7e1773..5e67939 100644
--- a/lib9p/protogen/c.py
+++ b/lib9p/protogen/c.py
@@ -76,12 +76,12 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
ret += "#endif\n"
ret += "\n"
ret += "/**\n"
- ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n"
- ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n"
+ ret += f" * is_ver(ctx, ver) is essentially `(ctx->version == {c9util.Ident('VER_')}##ver)`, but\n"
+ ret += f" * compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n"
ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n"
ret += " * several version checks together.\n"
ret += " */\n"
- ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n"
+ ret += "#define is_ver(ctx, ver) _is_ver_##ver((ctx)->version)\n"
# strings ##################################################################
ret += f"""
@@ -185,14 +185,14 @@ const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] =
)
ret += f"""
-LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{
-\treturn validate_stat(ctx);
+LM_FLATTEN ssize_t {c9util.ident('_stat_validate')}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{
+\treturn validate_stat(ctx, net_size, net_bytes, ret_net_size);
}}
-LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{
-\tunmarshal_stat(ctx, out);
+LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct lib9p_ctx *ctx, uint8_t *net_bytes, void *out) {{
+\tunmarshal_stat(ctx, net_bytes, out);
}}
-LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{
-\treturn marshal_stat(ctx, val);
+LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct lib9p_ctx *ctx, struct {c9util.ident('stat')} *val, struct _marshal_ret *ret) {{
+\treturn marshal_stat(ctx, val, ret);
}}
"""
diff --git a/lib9p/protogen/c9util.py b/lib9p/protogen/c9util.py
index f9c49fc..e7ad999 100644
--- a/lib9p/protogen/c9util.py
+++ b/lib9p/protogen/c9util.py
@@ -107,11 +107,3 @@ def idl_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
case _:
assert False
return " ".join(ret)
-
-
-def arg_used(arg: str) -> str:
- return arg
-
-
-def arg_unused(arg: str) -> str:
- return f"LM_UNUSED({arg})"
diff --git a/lib9p/protogen/c_marshal.py b/lib9p/protogen/c_marshal.py
index 152206d..74b64f5 100644
--- a/lib9p/protogen/c_marshal.py
+++ b/lib9p/protogen/c_marshal.py
@@ -17,6 +17,162 @@ from . import c9util, cutil, idlutil
# pylint: disable=unused-variable
__all__ = ["gen_c_marshal"]
+# get_offset_expr() ############################################################
+
+
+class OffsetExpr:
+ """Symbolic byte count that gen_c() renders as C code.
+
+ The value is the sum of a constant term, repeated terms (a count member
+ times a per-element sub-expression), and terms that only apply to
+ certain versions.
+ """
+
+ # Constant term; get_offset_expr() accumulates `static_size` values here.
+ static: int
+ # Version-conditional terms, keyed by the set of versions they apply to.
+ cond: dict[frozenset[str], "OffsetExpr"]
+ # Repeated terms: (path of the count member, expression for one element).
+ rep: list[tuple[idlutil.Path, "OffsetExpr"]]
+
+ def __init__(self) -> None:
+ """Create an empty expression (evaluates to 0)."""
+ self.static = 0
+ self.rep = []
+ self.cond = {}
+
+ def add(self, other: "OffsetExpr") -> None:
+ """Accumulate `other` into this expression in place.
+
+ Sub-expressions of `other` are referenced, not copied: a `cond` entry
+ that does not exist here yet is stored as the very same object.
+ """
+ self.static += other.static
+ self.rep += other.rep
+ for k, v in other.cond.items():
+ if k in self.cond:
+ self.cond[k].add(v)
+ else:
+ self.cond[k] = v
+
+ def gen_c(
+ self,
+ dsttyp: str,
+ dstvar: str,
+ root: str,
+ indent_depth: int,
+ loop_depth: int,
+ ) -> str:
+ """Return C statements that compute this expression into `dstvar`.
+
+ Args:
+ dsttyp: C type used to declare `dstvar`; when empty, `dstvar` is
+ assumed to exist already and is only incremented.
+ dstvar: name of the C variable receiving the total.
+ root: prefix handed to `Path.c_str()` when naming count members.
+ indent_depth: number of tab characters to indent emitted lines.
+ loop_depth: number of enclosing loops; selects the loop variable.
+ """
+ # Terms that fit in a single `a + b + c` sum.
+ oneline: list[str] = []
+ # Terms that need their own `for`/`if` blocks, emitted after the sum.
+ multiline = ""
+ if self.static:
+ oneline.append(str(self.static))
+ for cnt, sub in self.rep:
+ # An element with only a constant size folds into the sum as
+ # `count*size`; no loop is needed.
+ if not sub.cond and not sub.rep:
+ if sub.static == 1:
+ oneline.append(cnt.c_str(root))
+ else:
+ oneline.append(f"({cnt.c_str(root)})*{sub.static}")
+ continue
+ # Loop variables are named i, j, k, ... by nesting depth.
+ loopvar = chr(ord("i") + loop_depth)
+ multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
+ multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth + 1)
+ multiline += f"{'\t'*indent_depth}}}\n"
+ for vers, sub in self.cond.items():
+ # NOTE(review): unlike the marshal pass, no parentheses are added
+ # around ver_cond() here — confirm ver_cond() supplies its own.
+ multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
+ multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
+ multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
+ multiline += f"{'\t'*indent_depth}}}\n"
+ multiline += cutil.ifdef_pop(indent_depth)
+ ret = ""
+ if dsttyp:
+ # Declaring the variable: always emit an initializer, even if 0.
+ if not oneline:
+ oneline.append("0")
+ ret += f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
+ elif oneline:
+ ret += f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
+ ret += multiline
+ return ret
+
+
+# Callback consulted by get_offset_expr() for each visited path; it returns the
+# idlutil.WalkCmd that decides whether that path is counted.
+type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]
+
+
+def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
+ """Build the OffsetExpr describing the size of `typ`.
+
+ For a non-struct type this is just its `static_size`. For a struct the
+ members are visited with idlutil.walk(); `recurse` is asked first for
+ every visited path, and any answer other than KEEP_GOING is handed
+ straight back to the walker without counting that path.
+ """
+ if not isinstance(typ, idl.Struct):
+ assert typ.static_size
+ ret = OffsetExpr()
+ ret.static = typ.static_size
+ return ret
+
+ # One frame of the expression stack: the path that opened it, the
+ # expression being accumulated, and the function that folds the frame
+ # into the one below it.
+ class ExprStackItem(typing.NamedTuple):
+ path: idlutil.Path
+ expr: OffsetExpr
+ pop: typing.Callable[[], None]
+
+ expr_stack: list[ExprStackItem]
+
+ # The root frame is never popped; its expression is the final result.
+ def pop_root() -> None:
+ assert False
+
+ # Fold the top frame into the parent's `cond` map, keyed by the versions
+ # of the member that opened the frame.
+ def pop_cond() -> None:
+ nonlocal expr_stack
+ key = frozenset(expr_stack[-1].path.elems[-1].in_versions)
+ if key in expr_stack[-2].expr.cond:
+ expr_stack[-2].expr.cond[key].add(expr_stack[-1].expr)
+ else:
+ expr_stack[-2].expr.cond[key] = expr_stack[-1].expr
+ expr_stack = expr_stack[:-1]
+
+ # Fold the top frame into the parent's `rep` list, paired with the path
+ # of the member's count field (a sibling of the member itself).
+ def pop_rep() -> None:
+ nonlocal expr_stack
+ member_path = expr_stack[-1].path
+ member = member_path.elems[-1]
+ assert member.cnt
+ cnt_path = member_path.parent().add(member.cnt)
+ expr_stack[-2].expr.rep.append((cnt_path, expr_stack[-1].expr))
+ expr_stack = expr_stack[:-1]
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
+ nonlocal recurse
+
+ # Let the caller's policy veto this path before anything is counted.
+ ret = recurse(path)
+ if ret != idlutil.WalkCmd.KEEP_GOING:
+ return ret, None
+
+ nonlocal expr_stack
+ expr_stack_len = len(expr_stack)
+
+ # Unwind every frame pushed by this call to handle().
+ # presumably invoked by idlutil.walk() once the path's children have
+ # been visited — confirm against idlutil.walk()
+ def pop() -> None:
+ nonlocal expr_stack
+ nonlocal expr_stack_len
+ while len(expr_stack) > expr_stack_len:
+ expr_stack[-1].pop()
+
+ if path.elems:
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ # Member exists in fewer versions than its parent (strict-subset
+ # test, assuming in_versions is a set — TODO confirm).
+ if child.in_versions < parent.in_versions:
+ expr_stack.append(
+ ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_cond)
+ )
+ # Repeated member: its size is counted once per element.
+ if child.cnt:
+ expr_stack.append(
+ ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_rep)
+ )
+ # Struct members contribute through their own members; only
+ # non-struct members add a constant size here.
+ if not isinstance(child.typ, idl.Struct):
+ assert child.typ.static_size
+ expr_stack[-1].expr.static += child.typ.static_size
+ return ret, pop
+
+ expr_stack = [
+ ExprStackItem(path=idlutil.Path(typ), expr=OffsetExpr(), pop=pop_root)
+ ]
+ idlutil.walk(typ, handle)
+ return expr_stack[0].expr
+
+
+def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
+ """Recursion policy for get_offset_expr() that never stops: always KEEP_GOING."""
+ return idlutil.WalkCmd.KEEP_GOING
+
+
+def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
+ """Return a recursion policy that answers ABORT at the member called `name`.
+
+ Only a path of exactly one element (a direct member of the root) is
+ compared against `name`; every other path gets KEEP_GOING.
+ ABORT presumably ends the walk, so only members before `name` are
+ counted — confirm against idlutil.walk().
+ """
+ def ret(path: idlutil.Path) -> idlutil.WalkCmd:
+ if len(path.elems) == 1 and path.elems[0].membname == name:
+ return idlutil.WalkCmd.ABORT
+ return idlutil.WalkCmd.KEEP_GOING
+
+ return ret
+
+
+# Generate .c ##################################################################
+
def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
ret = """
@@ -25,188 +181,164 @@ def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
"""
ret += cutil.macro(
"#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
- "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
- "\t\tctx->net_iov_cnt++;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
- "\tctx->net_iov_cnt++;\n"
+ "\tif (ret->net_iov[ret->net_iov_cnt-1].iov_len)\n"
+ "\t\tret->net_iov_cnt++;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_base = data;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len = len;\n"
+ "\tret->net_iov_cnt++;\n"
)
ret += cutil.macro(
"#define MARSHAL_BYTES(ctx, data, len)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
- "\tctx->net_copied_size += len;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tmemcpy(&ret->net_copied[ret->net_copied_size], data, len);\n"
+ "\tret->net_copied_size += len;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += len;\n"
)
ret += cutil.macro(
"#define MARSHAL_U8LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tctx->net_copied[ctx->net_copied_size] = val;\n"
- "\tctx->net_copied_size += 1;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tret->net_copied[ret->net_copied_size] = val;\n"
+ "\tret->net_copied_size += 1;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 1;\n"
)
ret += cutil.macro(
"#define MARSHAL_U16LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 2;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint16le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 2;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 2;\n"
)
ret += cutil.macro(
"#define MARSHAL_U32LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 4;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint32le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 4;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 4;\n"
)
ret += cutil.macro(
"#define MARSHAL_U64LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 8;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
+ "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
+ "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
+ "\tuint64le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
+ "\tret->net_copied_size += 8;\n"
+ "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 8;\n"
)
- class OffsetExpr:
- static: int
- cond: dict[frozenset[str], "OffsetExpr"]
- rep: list[tuple[idlutil.Path, "OffsetExpr"]]
-
- def __init__(self) -> None:
- self.static = 0
- self.rep = []
- self.cond = {}
-
- def add(self, other: "OffsetExpr") -> None:
- self.static += other.static
- self.rep += other.rep
- for k, v in other.cond.items():
- if k in self.cond:
- self.cond[k].add(v)
- else:
- self.cond[k] = v
-
- def gen_c(
- self,
- dsttyp: str,
- dstvar: str,
- root: str,
- indent_depth: int,
- loop_depth: int,
- ) -> str:
- oneline: list[str] = []
- multiline = ""
- if self.static:
- oneline.append(str(self.static))
- for cnt, sub in self.rep:
- if not sub.cond and not sub.rep:
- if sub.static == 1:
- oneline.append(cnt.c_str(root))
- else:
- oneline.append(f"({cnt.c_str(root)})*{sub.static}")
- continue
- loopvar = chr(ord("i") + loop_depth)
- multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
- multiline += sub.gen_c(
- "", dstvar, root, indent_depth + 1, loop_depth + 1
- )
- multiline += f"{'\t'*indent_depth}}}\n"
- for vers, sub in self.cond.items():
- multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
- multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
- multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
- multiline += f"{'\t'*indent_depth}}}\n"
- multiline += cutil.ifdef_pop(indent_depth)
- if dsttyp:
- if not oneline:
- oneline.append("0")
- ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
- elif oneline:
- ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
- ret += multiline
- return ret
-
- type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]
-
- def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
- if not isinstance(typ, idl.Struct):
- assert typ.static_size
- ret = OffsetExpr()
- ret.static = typ.static_size
- return ret
-
- stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]]
-
- def pop_root() -> None:
- assert False
-
- def pop_cond() -> None:
- nonlocal stack
- key = frozenset(stack[-1][0].elems[-1].in_versions)
- if key in stack[-2][1].cond:
- stack[-2][1].cond[key].add(stack[-1][1])
- else:
- stack[-2][1].cond[key] = stack[-1][1]
- stack = stack[:-1]
-
- def pop_rep() -> None:
- nonlocal stack
- member_path = stack[-1][0]
- member = member_path.elems[-1]
- assert member.cnt
- cnt_path = member_path.parent().add(member.cnt)
- stack[-2][1].rep.append((cnt_path, stack[-1][1]))
- stack = stack[:-1]
-
- def handle(
- path: idlutil.Path,
- ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
- nonlocal recurse
-
- ret = recurse(path)
- if ret != idlutil.WalkCmd.KEEP_GOING:
- return ret, None
-
- nonlocal stack
- stack_len = len(stack)
-
- def pop() -> None:
- nonlocal stack
- nonlocal stack_len
- while len(stack) > stack_len:
- stack[-1][2]()
-
- if path.elems:
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- if child.in_versions < parent.in_versions:
- stack.append((path, OffsetExpr(), pop_cond))
- if child.cnt:
- stack.append((path, OffsetExpr(), pop_rep))
- if not isinstance(child.typ, idl.Struct):
- assert child.typ.static_size
- stack[-1][1].static += child.typ.static_size
- return ret, pop
-
- stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)]
- idlutil.walk(typ, handle)
- return stack[0][1]
+ class IndentLevel(typing.NamedTuple):
+ ifdef: bool # whether this is both `{` and `#if`, or just `{`
- def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
- return idlutil.WalkCmd.KEEP_GOING
+ indent_stack: list[IndentLevel]
- def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
- def ret(path: idlutil.Path) -> idlutil.WalkCmd:
- if len(path.elems) == 1 and path.elems[0].membname == name:
- return idlutil.WalkCmd.ABORT
- return idlutil.WalkCmd.KEEP_GOING
+ def ifdef_lvl() -> int:
+ return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
- return ret
+ def indent_lvl() -> int:
+ return len(indent_stack)
+
+ max_size: int
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+ """Walk callback: append the C marshalling statements for `path` to `ret`.
+
+ Always answers KEEP_GOING, together with a `pop` callback that closes
+ the `{` / `#if` levels this call pushed onto `indent_stack`.
+ """
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal max_size
+ indent_stack_len = len(indent_stack)
+
+ def pop() -> None:
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal indent_stack_len
+ while len(indent_stack) > indent_stack_len:
+ # NOTE(review): a single trailing version-`if` level is left
+ # open here, presumably so that following members guarded by
+ # the same versions can share it — confirm against
+ # cutil.ifdef_push().
+ if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
+ break
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)
+ struct = path.elems[-1].typ if path.elems else path.root
+ if isinstance(struct, idl.Struct):
+ # Collect the `end` and `&member` symbols used by this struct's
+ # member value expressions, in first-seen order without duplicates.
+ offsets: list[str] = []
+ for member in struct.members:
+ if not member.val:
+ continue
+ for tok in member.val.tokens:
+ if not isinstance(tok, idl.ExprSym):
+ continue
+ if tok.symname == "end" or tok.symname.startswith("&"):
+ if tok.symname not in offsets:
+ offsets.append(tok.symname)
+ # Emit one `uint32_t offsetof_..._<name>` C variable per symbol.
+ for name in offsets:
+ name_prefix = f"offsetof{''.join('_'+m.membname for m in path.elems)}_"
+ if name == "end":
+ if not path.elems:
+ # At the root, `end` is the total size; reuse the C
+ # variable `needed_size` instead of recomputing it.
+ # presumably needed_size is uint64_t when max_size exceeds
+ # UINT32_MAX, hence the cast — confirm against the caller
+ if max_size > cutil.UINT32_MAX:
+ ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
+ else:
+ ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = needed_size;\n"
+ continue
+ recurse: OffsetExprRecursion = go_to_end
+ else:
+ assert name.startswith("&")
+ name = name[1:]
+ recurse = go_to_tok(name)
+ expr = get_offset_expr(struct, recurse)
+ expr_prefix = path.c_str("val->", loopdepth)
+ if not expr_prefix.endswith(">"):
+ expr_prefix += "."
+ ret += expr.gen_c(
+ "uint32_t",
+ name_prefix + name,
+ expr_prefix,
+ indent_lvl(),
+ loopdepth,
+ )
+ if not path.elems:
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ # Member exists in fewer versions than its parent: guard it with a
+ # preprocessor check and a runtime version check.
+ if child.in_versions < parent.in_versions:
+ # Only open a new level when ifdef_push() actually emitted something.
+ if line := cutil.ifdef_push(
+ ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
+ ):
+ ret += line
+ ret += (
+ f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+ )
+ indent_stack.append(IndentLevel(ifdef=True))
+ if child.cnt:
+ cnt_path = path.parent().add(child.cnt)
+ # Arrays of 1-byte elements are emitted as a single bulk call.
+ # `[:-3]` presumably strips a trailing `[i]` index from the C
+ # expression — confirm against Path.c_str().
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ if path.root.typname == "stat": # SPECIAL (stat)
+ ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ else:
+ ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+ # loopdepth already counts this member, hence the `- 1`.
+ loopvar = chr(ord("i") + loopdepth - 1)
+ ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
+ indent_stack.append(IndentLevel(ifdef=False))
+ if not isinstance(child.typ, idl.Struct):
+ if child.val:
+
+ # Map an expression symbol to the `offsetof_...` C variable
+ # emitted for the enclosing struct.
+ def lookup_sym(sym: str) -> str:
+ nonlocal path
+ if sym.startswith("&"):
+ sym = sym[1:]
+ return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
+
+ val = c9util.idl_expr(child.val, lookup_sym)
+ else:
+ val = path.c_str("val->")
+ # Bitfields are masked with the per-version mask table.
+ if isinstance(child.typ, idl.Bitfield):
+ val += f" & {child.typ.typname}_masks[ctx->version]"
+ ret += f"{'\t'*indent_lvl()}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
for typ in typs:
if not (
@@ -216,7 +348,7 @@ def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
assert isinstance(typ, idl.Struct)
ret += "\n"
ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
- ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n"
+ ret += f"static bool marshal_{typ.typname}(struct lib9p_ctx *ctx, {c9util.typename(typ)} *val, struct _marshal_ret *ret) {{\n"
# Pass 1 - check size
max_size = max(typ.max_size(v) for v in typ.in_versions)
@@ -225,132 +357,29 @@ def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
ret += get_offset_expr(typ, go_to_end).gen_c(
"uint64_t", "needed_size", "val->", 1, 0
)
- ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
+ ret += "\tif (needed_size > (uint64_t)(ctx->max_msg_size)) {\n"
else:
ret += get_offset_expr(typ, go_to_end).gen_c(
"uint32_t", "needed_size", "val->", 1, 0
)
- ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
+ ret += "\tif (needed_size > ctx->max_msg_size) {\n"
if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
- ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
+ ret += '\t\tlib9p_errorf(ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
ret += f'\t\t\t"{typ.typname}",\n'
- ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
- ret += "\t\t\tctx->ctx->max_msg_size);\n"
+ ret += f'\t\t\tctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
+ ret += "\t\t\tctx->max_msg_size);\n"
ret += "\t\treturn true;\n"
ret += "\t}\n"
# Pass 2 - write data
- ifdef_depth = 1
- stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)]
-
- def handle(
- path: idlutil.Path,
- ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
- nonlocal ret
- nonlocal ifdef_depth
- nonlocal stack
- stack_len = len(stack)
-
- def pop() -> None:
- nonlocal ret
- nonlocal ifdef_depth
- nonlocal stack
- nonlocal stack_len
- while len(stack) > stack_len:
- ret += f"{'\t'*(len(stack)-1)}}}\n"
- if stack[-1][1]:
- ifdef_depth -= 1
- ret += cutil.ifdef_pop(ifdef_depth)
- stack = stack[:-1]
-
- loopdepth = sum(1 for elem in path.elems if elem.cnt)
- struct = path.elems[-1].typ if path.elems else path.root
- if isinstance(struct, idl.Struct):
- offsets: list[str] = []
- for member in struct.members:
- if not member.val:
- continue
- for tok in member.val.tokens:
- if not isinstance(tok, idl.ExprSym):
- continue
- if tok.symname == "end" or tok.symname.startswith("&"):
- if tok.symname not in offsets:
- offsets.append(tok.symname)
- for name in offsets:
- name_prefix = "offsetof_" + "".join(
- m.membname + "_" for m in path.elems
- )
- if name == "end":
- if not path.elems:
- nonlocal max_size
- if max_size > cutil.UINT32_MAX:
- ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
- else:
- ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n"
- continue
- recurse: OffsetExprRecursion = go_to_end
- else:
- assert name.startswith("&")
- name = name[1:]
- recurse = go_to_tok(name)
- expr = get_offset_expr(struct, recurse)
- expr_prefix = path.c_str("val->", loopdepth)
- if not expr_prefix.endswith(">"):
- expr_prefix += "."
- ret += expr.gen_c(
- "uint32_t",
- name_prefix + name,
- expr_prefix,
- len(stack),
- loopdepth,
- )
- if path.elems:
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- if child.in_versions < parent.in_versions:
- ret += cutil.ifdef_push(
- ifdef_depth + 1, c9util.ver_ifdef(child.in_versions)
- )
- ifdef_depth += 1
- ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n"
- stack.append((path, True))
- if child.cnt:
- cnt_path = path.parent().add(child.cnt)
- if child.typ.static_size == 1: # SPECIAL (zerocopy)
- if path.root.typname == "stat": # SPECIAL (stat)
- ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
- else:
- ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
- return idlutil.WalkCmd.KEEP_GOING, pop
- loopvar = chr(ord("i") + loopdepth - 1)
- ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
- stack.append((path, False))
- if not isinstance(child.typ, idl.Struct):
- if child.val:
-
- def lookup_sym(sym: str) -> str:
- nonlocal path
- if sym.startswith("&"):
- sym = sym[1:]
- return (
- "offsetof_"
- + "".join(m.membname + "_" for m in path.elems[:-1])
- + sym
- )
-
- val = c9util.idl_expr(child.val, lookup_sym)
- else:
- val = path.c_str("val->")
- if isinstance(child.typ, idl.Bitfield):
- val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
- ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
- return idlutil.WalkCmd.KEEP_GOING, pop
-
+ indent_stack = [IndentLevel(ifdef=True)]
idlutil.walk(typ, handle)
- del handle
- del stack
- del max_size
+ while len(indent_stack) > 1:
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+ # Return
ret += "\treturn false;\n"
ret += "}\n"
ret += cutil.ifdef_pop(0)
diff --git a/lib9p/protogen/c_unmarshal.py b/lib9p/protogen/c_unmarshal.py
index e17f456..018d750 100644
--- a/lib9p/protogen/c_unmarshal.py
+++ b/lib9p/protogen/c_unmarshal.py
@@ -3,6 +3,7 @@
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
+import typing
import idl
@@ -21,72 +22,112 @@ def gen_c_unmarshal(versions: set[str], typs: list[idl.UserType]) -> str:
ret = """
/* unmarshal_* ****************************************************************/
-LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
-\t*out = ctx->net_bytes[ctx->net_offset];
-\tctx->net_offset += 1;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
-\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 2;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
-\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 4;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
-\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 8;
-}
"""
- for typ in idlutil.topo_sorted(typs):
- inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
- argfn = (
- c9util.arg_unused
- if (isinstance(typ, idl.Struct) and not typ.members)
- else c9util.arg_used
- )
+ ret += cutil.macro(
+ "#define UNMARSHAL_BYTES(ctx, data_lvalue, len)\n"
+ "\tdata_lvalue = (char *)&net_bytes[net_offset];\n"
+ "\tnet_offset += len;\n"
+ )
+ ret += cutil.macro(
+ "#define UNMARSHAL_U8LE(ctx, val_lvalue)\n"
+ "\tval_lvalue = net_bytes[net_offset];\n"
+ "\tnet_offset += 1;\n"
+ )
+ ret += cutil.macro(
+ "#define UNMARSHAL_U16LE(ctx, val_lvalue)\n"
+ "\tval_lvalue = uint16le_decode(&net_bytes[net_offset]);\n"
+ "\tnet_offset += 2;\n"
+ )
+ ret += cutil.macro(
+ "#define UNMARSHAL_U32LE(ctx, val_lvalue)\n"
+ "\tval_lvalue = uint32le_decode(&net_bytes[net_offset]);\n"
+ "\tnet_offset += 4;\n"
+ )
+ ret += cutil.macro(
+ "#define UNMARSHAL_U64LE(ctx, val_lvalue)\n"
+ "\tval_lvalue = uint64le_decode(&net_bytes[net_offset]);\n"
+ "\tnet_offset += 8;\n"
+ )
+
+ class IndentLevel(typing.NamedTuple):
+ ifdef: bool # whether this is both `{` and `#if`, or just `{`
+
+ indent_stack: list[IndentLevel]
+
+ def ifdef_lvl() -> int:
+ return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
+
+ def indent_lvl() -> int:
+ return len(indent_stack)
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+ """Walk callback: append the C unmarshalling statements for `path` to `ret`.
+
+ Always answers KEEP_GOING, together with a `pop` callback that closes
+ the `{` / `#if` levels this call pushed onto `indent_stack`.
+ """
+ nonlocal ret
+ nonlocal indent_stack
+ indent_stack_len = len(indent_stack)
+
+ def pop() -> None:
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal indent_stack_len
+ while len(indent_stack) > indent_stack_len:
+ # NOTE(review): a single trailing version-`if` level is left
+ # open here, presumably so that following members guarded by
+ # the same versions can share it — confirm against
+ # cutil.ifdef_push().
+ if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
+ break
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ # The root path has no member of its own to decode.
+ if not path.elems:
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ # Member exists in fewer versions than its parent: guard it with a
+ # preprocessor check and a runtime version check.
+ if child.in_versions < parent.in_versions:
+ # Only open a new level when ifdef_push() actually emitted something.
+ if line := cutil.ifdef_push(
+ ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
+ ):
+ ret += line
+ ret += (
+ f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+ )
+ indent_stack.append(IndentLevel(ifdef=True))
+ if child.cnt:
+ cnt_path = path.parent().add(child.cnt)
+ # Arrays of 1-byte elements are not copied: UNMARSHAL_BYTES points
+ # the output member into net_bytes and skips over the data.
+ # `[:-3]` presumably strips a trailing `[i]` index from the C
+ # expression — confirm against Path.c_str().
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret += f"{'\t'*indent_lvl()}UNMARSHAL_BYTES(ctx, {path.c_str('out->')[:-3]}, {cnt_path.c_str('out->')});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+ # Other arrays take their storage from the C variable `extra`,
+ # which is then advanced past the array.
+ ret += f"{'\t'*indent_lvl()}{path.c_str('out->')[:-3]} = extra;\n"
+ ret += f"{'\t'*indent_lvl()}extra += sizeof({path.c_str('out->')[:-3]}[0]) * {cnt_path.c_str('out->')};\n"
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)
+ # loopdepth already counts this member, hence the `- 1`.
+ loopvar = chr(ord("i") + loopdepth - 1)
+ ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('out->')}; {loopvar}++) {{\n"
+ indent_stack.append(IndentLevel(ifdef=False))
+ if not isinstance(child.typ, idl.Struct):
+ # Members with a computed value are not stored in the output;
+ # their bytes are only skipped.
+ if child.val:
+ ret += f"{'\t'*indent_lvl()}net_offset += {child.typ.static_size};\n"
+ else:
+ ret += f"{'\t'*indent_lvl()}UNMARSHAL_U{child.typ.static_size*8}LE(ctx, {path.c_str('out->')});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ for typ in typs:
+ if not (
+ isinstance(typ, idl.Message) or typ.typname == "stat"
+ ): # SPECIAL (include stat)
+ continue
+ assert isinstance(typ, idl.Struct)
ret += "\n"
ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
- ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n"
- match typ:
- case idl.Number():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
- case idl.Bitfield():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
- case idl.Struct():
- ret += "\tmemset(out, 0, sizeof(*out));\n"
-
- for member in typ.members:
- ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
- if member.val:
- ret += f"\tctx->net_offset += {member.static_size};\n"
- continue
- ret += "\t"
-
- prefix = "\t"
- if member.in_versions != typ.in_versions:
- ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) "
- prefix = "\t\t"
- if member.cnt:
- if member.in_versions != typ.in_versions:
- ret += "{\n"
- ret += prefix
- if member.typ.static_size == 1: # SPECIAL (string, zerocopy)
- ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
- ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n"
- else:
- ret += f"out->{member.membname} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n"
- ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n"
- if member.in_versions != typ.in_versions:
- ret += "\t}\n"
- else:
- ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n"
- ret += cutil.ifdef_pop(1)
- ret += "}\n"
+ ret += f"static void unmarshal_{typ.typname}([[gnu::unused]] struct lib9p_ctx *ctx, uint8_t *net_bytes, void *out_buf) {{\n"
+ ret += f"\t{c9util.typename(typ)} *out = out_buf;\n"
+ ret += "\t[[gnu::unused]] void *extra = &out[1];\n"
+ ret += "\tuint32_t net_offset = 0;\n"
+
+ indent_stack = [IndentLevel(ifdef=True)]
+ idlutil.walk(typ, handle)
+ while len(indent_stack) > 0:
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef and indent_stack:
+ ret += cutil.ifdef_pop(ifdef_lvl())
ret += cutil.ifdef_pop(0)
return ret
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
index a3f4348..e315b60 100644
--- a/lib9p/protogen/c_validate.py
+++ b/lib9p/protogen/c_validate.py
@@ -3,6 +3,7 @@
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
+import typing
import idl
@@ -17,155 +18,270 @@ from . import c9util, cutil, idlutil
__all__ = ["gen_c_validate"]
-def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
- return bool(member.max or member.val or any(m.cnt == member for m in typ.members))
+def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
+ if child.val or child.max or isinstance(child.typ, idl.Bitfield):
+ return True
+ for sibling in parent.members:
+ if sibling.val:
+ for tok in sibling.val.tokens:
+ if isinstance(tok, idl.ExprSym) and tok.symname == f"&{child.membname}":
+ return True
+ if sibling.max:
+ for tok in sibling.max.tokens:
+ if isinstance(tok, idl.ExprSym) and tok.symname == f"&{child.membname}":
+ return True
+ return False
+
+
+def should_save_end_offset(struct: idl.Struct) -> bool:
+ for memb in struct.members:
+ if memb.val:
+ for tok in memb.val.tokens:
+ if isinstance(tok, idl.ExprSym) and tok.symname == "end":
+ return True
+ if memb.max:
+ for tok in memb.max.tokens:
+ if isinstance(tok, idl.ExprSym) and tok.symname == "end":
+ return True
+ return False
def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
ret = """
/* validate_* *****************************************************************/
-LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
-\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
-\t\t/* If needed-net-size overflowed uint32_t, then
-\t\t * there's no way that actual-net-size will live up to
-\t\t * that. */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\tif (ctx->net_offset > ctx->net_size)
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
-\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
-\t\t/* If needed-host-size overflowed size_t, then there's
-\t\t * no way that actual-net-size will live up to
-\t\t * that. */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
- size_t cnt,
- _validate_fn_t item_fn, size_t item_host_size) {
-\tfor (size_t i = 0; i < cnt; i++)
-\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
-\t\t\treturn true;
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
-LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
-LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
-LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
+ ret += cutil.macro(
+ "#define VALIDATE_NET_BYTES(n)\n"
+ "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
+ "\t\t/* If needed-net-size overflowed uint32_t, then\n"
+ "\t\t * there's no way that actual-net-size will live up to\n"
+ "\t\t * that. */\n"
+ '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
+ "\tif (net_offset > net_size)\n"
+ '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
+ )
+ ret += cutil.macro(
+ "#define VALIDATE_NET_UTF8(n)\n"
+ "\t{\n"
+ "\t\tsize_t len = n;\n"
+ "\t\tVALIDATE_NET_BYTES(len);\n"
+ "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
+ '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
+ "\t}\n"
+ )
+ ret += cutil.macro(
+ "#define RESERVE_HOST_BYTES(n)\n"
+ "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
+ "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
+ "\t\t * no way that actual-net-size will live up to\n"
+ "\t\t * that. */\n"
+ '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
+ )
+
+ ret += "#define GET_U8LE(off) (net_bytes[off])\n"
+ ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
+ ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
+ ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"
+
+ ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
+ ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
+ ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
+ ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"
+
+ class IndentLevel(typing.NamedTuple):
+ ifdef: bool # whether this is both `{` and `#if`, or just `{`
+
+ indent_stack: list[IndentLevel]
+
+ def ifdef_lvl() -> int:
+ return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
+
+ def indent_lvl() -> int:
+ return len(indent_stack)
+
+ incr_buf: int
+
+ def incr_flush() -> None:
+ nonlocal ret
+ nonlocal incr_buf
+ if incr_buf:
+ ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES({incr_buf});\n"
+ incr_buf = 0
+
+ def gen_validate_size(path: idlutil.Path) -> None:
+ nonlocal ret
+ nonlocal incr_buf
+ nonlocal indent_stack
+
+ assert path.elems
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ assert isinstance(parent, idl.Struct)
+
+ if child.in_versions < parent.in_versions:
+ if line := cutil.ifdef_push(
+ ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
+ ):
+ incr_flush()
+ ret += line
+ ret += (
+ f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+ )
+ indent_stack.append(IndentLevel(ifdef=True))
+ if should_save_offset(parent, child):
+ ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)} = net_offset + {incr_buf};\n"
+ if child.cnt:
+ assert child.cnt.typ.static_size
+ cnt_path = path.parent().add(child.cnt)
+ incr_flush()
+ if child.membname == "utf8": # SPECIAL (string)
+ # Yes, this is content-validation and "belongs" in
+ # gen_validate_content(), not here. But it's just
+ # easier this way.
+ ret += f"{'\t'*indent_lvl()}VALIDATE_NET_UTF8(LAST_U{child.cnt.typ.static_size*8}LE());\n"
+ return
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES(LAST_U{child.cnt.typ.static_size*8}LE());\n"
+ return
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)
+ loopvar = chr(ord("i") + loopdepth - 1)
+ ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0, cnt = LAST_U{child.cnt.typ.static_size*8}LE(); {loopvar} < cnt; {loopvar}++) {{\n"
+ indent_stack.append(IndentLevel(ifdef=False))
+ ret += f"{'\t'*indent_lvl()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
+ if not isinstance(child.typ, idl.Struct):
+ incr_buf += child.typ.static_size
- for typ in idlutil.topo_sorted(typs):
- inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
- argfn = (
- c9util.arg_unused
- if (isinstance(typ, idl.Struct) and not typ.members)
- else c9util.arg_used
- )
+ def gen_validate_content(path: idlutil.Path) -> None:
+ nonlocal ret
+ nonlocal incr_buf
+ nonlocal indent_stack
+
+ assert path.elems
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ assert isinstance(parent, idl.Struct)
+
+ def lookup_sym(sym: str) -> str:
+ if sym.startswith("&"):
+ sym = sym[1:]
+ return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
+
+ if child.val:
+ incr_flush()
+ assert child.typ.static_size
+ nbits = child.typ.static_size * 8
+ nbits = child.typ.static_size * 8
+ if nbits < 32 and any(
+ isinstance(tok, idl.ExprSym)
+ and (tok.symname == "end" or tok.symname.startswith("&"))
+ for tok in child.val.tokens
+ ):
+ nbits = 32
+ act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+ exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
+ ret += f"{'\t'*indent_lvl()}if ({act} != {exp})\n"
+ ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
+ ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
+ if child.max:
+ incr_flush()
+ assert child.typ.static_size
+ nbits = child.typ.static_size * 8
+ if nbits < 32 and any(
+ isinstance(tok, idl.ExprSym)
+ and (tok.symname == "end" or tok.symname.startswith("&"))
+ for tok in child.max.tokens
+ ):
+ nbits = 32
+ act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+ exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
+ ret += f"{'\t'*indent_lvl()}if ({act} > {exp})\n"
+ ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
+ ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
+ if isinstance(child.typ, idl.Bitfield):
+ incr_flush()
+ nbytes = child.typ.static_size
+ nbits = nbytes * 8
+ act = f"GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+ ret += f"{'\t'*indent_lvl()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
+ ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
+ ret += f"{'\t'*(indent_lvl()+2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+ nonlocal ret
+ nonlocal incr_buf
+ nonlocal indent_stack
+ indent_stack_len = len(indent_stack)
+ pop_struct = path.elems[-1].typ if path.elems else path.root
+ pop_path = path
+ pop_indent_stack_len: int
+
+ def pop() -> None:
+ nonlocal ret
+ nonlocal indent_stack
+ nonlocal indent_stack_len
+ nonlocal pop_struct
+ nonlocal pop_path
+ nonlocal pop_indent_stack_len
+ if isinstance(pop_struct, idl.Struct):
+ while len(indent_stack) > pop_indent_stack_len:
+ incr_flush()
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+ parent = pop_struct
+ path = pop_path
+ if should_save_end_offset(parent):
+ ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)}_end = net_offset + {incr_buf};\n"
+ for child in parent.members:
+ gen_validate_content(pop_path.add(child))
+ while len(indent_stack) > indent_stack_len:
+ if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
+ break
+ incr_flush()
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ if path.elems:
+ gen_validate_size(path)
+
+ pop_indent_stack_len = len(indent_stack)
+
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ for typ in typs:
+ if not (
+ isinstance(typ, idl.Message) or typ.typname == "stat"
+ ): # SPECIAL (include stat)
+ continue
+ assert isinstance(typ, idl.Struct)
ret += "\n"
ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
- ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
-
- match typ:
- case idl.Number():
- ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
- case idl.Bitfield():
- ret += f"\t if (validate_{typ.static_size}(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
- if typ.static_size == 1:
- ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
- else:
- ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
- ret += "\tif (val & ~mask)\n"
- ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
- ret += "\treturn false;\n"
- case idl.Struct(): # and idl.Message()
- if len(typ.members) == 0:
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
-
- # Pass 1 - declare value variables
- for member in typ.members:
- if should_save_value(typ, member):
- ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
- ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
- ret += cutil.ifdef_pop(1)
-
- # Pass 2 - declare offset variables
- mark_offset: set[str] = set()
- for member in typ.members:
- for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
- if tok.symname[1:] not in mark_offset:
- ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
- mark_offset.add(tok.symname[1:])
-
- # Pass 3 - main pass
- ret += "\treturn false\n"
- for member in typ.members:
- ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
- ret += "\t || "
- if member.in_versions != typ.in_versions:
- ret += "( " + c9util.ver_cond(member.in_versions) + " && "
- if member.cnt is not None:
- if member.typ.static_size == 1: # SPECIAL (zerocopy)
- ret += f"_validate_size_net(ctx, {member.cnt.membname})"
- else:
- ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
- if typ.typname == "s": # SPECIAL (string)
- ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
- else:
- if should_save_value(typ, member):
- ret += "("
- if member.membname in mark_offset:
- ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
- ret += f"validate_{member.typ.typname}(ctx)"
- if member.membname in mark_offset:
- ret += "; })"
- if should_save_value(typ, member):
- nbytes = member.static_size
- assert nbytes
- if nbytes == 1:
- ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
- else:
- ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
- if member.in_versions != typ.in_versions:
- ret += " )"
- ret += "\n"
-
- # Pass 4 - validate ,max= and ,val= constraints
- for member in typ.members:
-
- def lookup_sym(sym: str) -> str:
- match sym:
- case "end":
- return "ctx->net_offset"
- case _:
- assert sym.startswith("&")
- return f"_{sym[1:]}_offset"
-
- if member.max:
- assert member.static_size
- nbits = member.static_size * 8
- ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
- if member.val:
- assert member.static_size
- nbits = member.static_size * 8
- ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
-
- ret += cutil.ifdef_pop(1)
- ret += "\t ;\n"
+ if typ.typname == "stat": # SPECIAL (stat)
+ ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
+ else:
+ ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"
+
+ ret += "\tuint32_t net_offset = 0;\n"
+ ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"
+
+ incr_buf = 0
+ indent_stack = [IndentLevel(ifdef=True)]
+ idlutil.walk(typ, handle)
+ while len(indent_stack) > 1:
+ incr_flush()
+ ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+ if indent_stack.pop().ifdef:
+ ret += cutil.ifdef_pop(ifdef_lvl())
+
+ incr_flush()
+ if typ.typname == "stat": # SPECIAL (stat)
+ ret += "\tif (ret_net_size)\n"
+ ret += "\t\t*ret_net_size = net_offset;\n"
+ ret += "\treturn (ssize_t)host_size;\n"
ret += "}\n"
ret += cutil.ifdef_pop(0)
return ret
diff --git a/lib9p/protogen/cutil.py b/lib9p/protogen/cutil.py
index a78cd17..8df6db9 100644
--- a/lib9p/protogen/cutil.py
+++ b/lib9p/protogen/cutil.py
@@ -31,7 +31,7 @@ def macro(full: str) -> str:
lines = [l.rstrip() for l in full.split("\n")]
width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1])
lines = [tab_ljust(l, width) for l in lines]
- return " \\\n".join(lines).rstrip() + "\n"
+ return " \\\n".join(lines).rstrip() + "\n"
_ifdef_stack: list[str | None] = []