summaryrefslogtreecommitdiff
path: root/lib9p/protogen/c_validate.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r--lib9p/protogen/c_validate.py299
1 files changed, 299 insertions, 0 deletions
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
new file mode 100644
index 0000000..535a750
--- /dev/null
+++ b/lib9p/protogen/c_validate.py
@@ -0,0 +1,299 @@
+# lib9p/protogen/c_validate.py - Generate C validation functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import typing
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_validate"]
+
+
def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
    """Report whether the net offset of `child` must be remembered.

    The offset is needed when `child` itself carries a `val` or `max`
    expression or is a bitfield, or when the `val`/`max` expression of
    any member of `parent` contains an offset token (`idl.ExprOff`)
    naming `child`.
    """
    if child.val or child.max or isinstance(child.typ, idl.Bitfield):
        return True
    # Every non-empty val/max expression found among the parent's members.
    exprs = (
        expr
        for sibling in parent.members
        for expr in (sibling.val, sibling.max)
        if expr
    )
    return any(
        isinstance(tok, idl.ExprOff) and tok.membname == child.membname
        for expr in exprs
        for tok in expr.tokens
    )
+
+
def should_save_end_offset(struct: idl.Struct) -> bool:
    """Report whether any member's `val`/`max` expression uses the `end` symbol.

    Returns True as soon as an `idl.ExprSym` token whose `symname` is
    "end" is found in a member's `val` or `max` expression.
    """
    exprs = (
        expr
        for memb in struct.members
        for expr in (memb.val, memb.max)
        if expr
    )
    return any(
        isinstance(tok, idl.ExprSym) and tok.symname == "end"
        for expr in exprs
        for tok in expr.tokens
    )
+
+
def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    """Generate the C source text of the `validate_*` functions.

    One `static ssize_t validate_<typname>(...)` function is emitted for
    every `idl.Message` in `typs`, plus one for the type named "stat"
    (SPECIAL).  Each generated function walks the wire representation,
    bounds-checks it against `net_size`, checks `val`/`max` constraints
    and bitfield masks, and returns the accumulated `host_size`.

    `versions` is accepted for interface compatibility; it is not
    consulted here.

    Returns the generated C code as a single string.
    """
    ret = """
/* validate_* *****************************************************************/

"""
    ret += cutil.macro(
        "#define VALIDATE_NET_BYTES(n)\n"
        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
        "\t\t * there's no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
        "\tif (net_offset > net_size)\n"
        '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
    )
    ret += cutil.macro(
        "#define VALIDATE_NET_UTF8(n)\n"
        "\t{\n"
        "\t\tsize_t len = n;\n"
        "\t\tVALIDATE_NET_BYTES(len);\n"
        "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
        '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
        "\t}\n"
    )
    ret += cutil.macro(
        "#define RESERVE_HOST_BYTES(n)\n"
        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
        "\t\t * no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
    )

    ret += "#define GET_U8LE(off) (net_bytes[off])\n"
    ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
    ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
    ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"

    ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    # Stack of currently-open C blocks; (re)initialized per generated function.
    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        """Count the open levels that are also `#if` levels."""
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)

    def indent_lvl() -> int:
        """Return the current C indentation depth."""
        return len(indent_stack)

    def tabs(delta: int = 0) -> str:
        """Return the tab prefix for the current indent depth plus `delta`."""
        return "\t" * (indent_lvl() + delta)

    def offset_var(elems: typing.Any) -> str:
        """Return the C variable name that holds the offset of the member path `elems`."""
        return "offsetof" + "".join("_" + m.membname for m in elems)

    # Count of statically-known net bytes not yet emitted as a
    # VALIDATE_NET_BYTES() call; lets adjacent fixed-size fields share
    # one bounds check.
    incr_buf: int

    def incr_flush() -> None:
        """Emit a VALIDATE_NET_BYTES() for the buffered byte count, if any."""
        nonlocal ret
        nonlocal incr_buf
        if incr_buf:
            ret += f"{tabs()}VALIDATE_NET_BYTES({incr_buf});\n"
            incr_buf = 0

    def close_level() -> None:
        """Close the innermost open block: flush, emit `}`, and pop its `#if` if it has one."""
        nonlocal ret
        incr_flush()
        ret += f"{tabs(-1)}}}\n"
        # ifdef_lvl() is deliberately evaluated *after* the pop.
        if indent_stack.pop().ifdef:
            ret += cutil.ifdef_pop(ifdef_lvl())

    def gen_validate_size(path: idlutil.Path) -> None:
        """Emit the size/bounds-checking code for the member at the end of `path`.

        May open new block levels (a version-conditional `if`, or a
        `for` loop over a counted member); the caller's `pop` closes
        them.
        """
        nonlocal ret
        nonlocal incr_buf

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        if child.in_versions < parent.in_versions:
            # Member exists only in a strict subset of the parent's
            # versions.  A new `if (...) {` level is opened only when
            # ifdef_push() yields a non-empty line.
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                incr_flush()
                ret += line
                ret += f"{tabs()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                indent_stack.append(IndentLevel(ifdef=True))
        if should_save_offset(parent, child):
            # Pending (unflushed) bytes are added in so the saved offset is exact.
            ret += f"{tabs()}uint32_t {offset_var(path.elems)} = net_offset + {incr_buf};\n"
        if child.cnt:
            if isinstance(child.cnt, int):
                cnt_str = str(child.cnt)
                cnt_typ = "size_t"
            else:
                # The count is another member; flush so that LAST_U*LE()
                # reads the bytes just before net_offset.
                assert child.cnt.typ.static_size
                incr_flush()
                cnt_str = f"LAST_U{child.cnt.typ.static_size*8}LE()"
                cnt_typ = c9util.typename(child.cnt.typ)
            if child.membname == "utf8":  # SPECIAL (string)
                assert child.typ.static_size == 1
                # Yes, this is content-validation and "belongs" in
                # gen_validate_content(), not here.  But it's just
                # easier this way.
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_UTF8({cnt_str});\n"
                return
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                if isinstance(child.cnt, int):
                    incr_buf += child.cnt
                    return
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_BYTES({cnt_str});\n"
                return
            # General counted member: loop over the elements.  The loop
            # variable is i, j, k, ... by nesting depth.
            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            loopvar = chr(ord("i") + loopdepth - 1)
            incr_flush()
            ret += f"{tabs()}for ({cnt_typ} {loopvar} = 0, cnt = {cnt_str}; {loopvar} < cnt; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
            ret += f"{tabs()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
        if not isinstance(child.typ, idl.Struct):
            incr_buf += child.typ.static_size

    def gen_validate_content(path: idlutil.Path) -> None:
        """Emit the `val`, `max`, and bitfield-mask checks for the member at the end of `path`."""
        nonlocal ret

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        def lookup_sym(sym: str) -> str:
            """Map an expression symbol (`end`, `&member`) to its offset variable name."""
            if sym.startswith("&"):
                sym = sym[1:]
            return f"{offset_var(path.elems[:-1])}_{sym}"

        def compare_bits(expr: typing.Any, field_bits: int) -> int:
            """Return the integer width to compare at.

            Fields narrower than 32 bits are compared as uint32_t when
            the expression involves offsets, since the offset variables
            are uint32_t.
            """
            if field_bits < 32 and any(
                isinstance(tok, idl.ExprSym)
                and (tok.symname == "end" or tok.symname.startswith("&"))
                for tok in expr.tokens
            ):
                return 32
            return field_bits

        field_off = lookup_sym(f"&{child.membname}")

        if child.val:
            incr_flush()
            assert child.typ.static_size
            field_bits = child.typ.static_size * 8
            nbits = compare_bits(child.val, field_bits)
            # Read the field at its own width and widen only the cast:
            # reading a 1- or 2-byte field with GET_U32LE would pull in
            # bytes that belong to the following field.
            act = f"(uint{nbits}_t)GET_U{field_bits}LE({field_off})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
            ret += f"{tabs()}if ({act} != {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if child.max:
            incr_flush()
            assert child.typ.static_size
            field_bits = child.typ.static_size * 8
            nbits = compare_bits(child.max, field_bits)
            # Same width rule as for `val` above.
            act = f"(uint{nbits}_t)GET_U{field_bits}LE({field_off})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
            ret += f"{tabs()}if ({act} > {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if isinstance(child.typ, idl.Bitfield):
            incr_flush()
            nbytes = child.typ.static_size
            nbits = nbytes * 8
            act = f"GET_U{nbits}LE({field_off})"
            ret += f"{tabs()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
            # NOTE(review): with standard printf semantics the `#` flag's
            # "0x" prefix counts toward the field width, so this width
            # may never produce zero-padding -- confirm against
            # lib9p_errorf's formatter before changing.
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
            ret += f"{tabs(2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        """Walk callback: emit size checks on entry and return the exit hook.

        The returned `pop` callable emits the content checks of a
        struct's members and closes the block levels opened beneath
        this node.
        """
        # Depth before this node's own size code ran.
        indent_stack_len = len(indent_stack)

        if path.elems:
            gen_validate_size(path)

        # Depth after this node's own size code ran (includes any
        # version-`if` or `for` level that it opened).
        pop_indent_stack_len = len(indent_stack)
        pop_struct = path.elems[-1].typ if path.elems else path.root

        def pop() -> None:
            nonlocal ret
            if isinstance(pop_struct, idl.Struct):
                # Close whatever the struct's members left open, so the
                # content checks are emitted at the struct's own level.
                while len(indent_stack) > pop_indent_stack_len:
                    close_level()
                if should_save_end_offset(pop_struct):
                    ret += f"{tabs()}uint32_t {offset_var(path.elems)}_end = net_offset + {incr_buf};\n"
                for child in pop_struct.members:
                    gen_validate_content(path.add(child))
            while len(indent_stack) > indent_stack_len:
                # Leave a single trailing version-conditional level open;
                # gen_validate_size() skips re-opening one when
                # ifdef_push() returns nothing for the next member.
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                close_level()

        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
        else:
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"

        ret += "\tuint32_t net_offset = 0;\n"
        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"

        incr_buf = 0
        # The function body itself is the first (ifdef-bearing) level.
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            close_level()

        incr_flush()
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += "\tif (ret_net_size)\n"
            ret += "\t\t*ret_net_size = net_offset;\n"
        ret += "\treturn (ssize_t)host_size;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret