diff options
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r-- | lib9p/protogen/c_validate.py | 299 |
1 files changed, 299 insertions, 0 deletions
# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]


def _constraint_tokens(memb: idl.StructMember) -> typing.Iterator[typing.Any]:
    """Yield every token of `memb`'s `val` and `max` expressions.

    An expression that is unset (falsy) contributes no tokens.
    """
    for expr in (memb.val, memb.max):
        if expr:
            yield from expr.tokens


def _offset_var(elems: typing.Iterable[typing.Any]) -> str:
    """Return the C variable name `offsetof_<memb>_<memb>...` for a member path.

    The same spelling is used where the variable is declared and where
    it is referenced, so it is built in exactly one place.
    """
    return "offsetof" + "".join("_" + m.membname for m in elems)


def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
    """Report whether the generated C must remember `child`'s net offset.

    That is the case when `child` itself is checked after the fact (it
    has a `val` or `max` constraint, or is a bitfield), or when any
    member of `parent` has a `val`/`max` expression containing an
    `idl.ExprOff` token that names `child`.
    """
    if child.val or child.max or isinstance(child.typ, idl.Bitfield):
        return True
    return any(
        isinstance(tok, idl.ExprOff) and tok.membname == child.membname
        for sibling in parent.members
        for tok in _constraint_tokens(sibling)
    )


def should_save_end_offset(struct: idl.Struct) -> bool:
    """Report whether any member constraint of `struct` uses the `end` symbol.

    Only `idl.ExprSym` tokens in the members' `val`/`max` expressions
    are considered.
    """
    return any(
        isinstance(tok, idl.ExprSym) and tok.symname == "end"
        for memb in struct.members
        for tok in _constraint_tokens(memb)
    )


def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    """Return C source text for the `validate_*` functions.

    One `static ssize_t validate_<typname>(...)` function is emitted
    for each `idl.Message` in `typs`, plus one for the type named
    "stat" (which additionally takes a `ret_net_size` out-parameter).
    The text starts with the helper macros that those functions use.

    `versions` is accepted as part of the generator interface but is
    not consulted by this function.
    """
    ret = """
/* validate_* *****************************************************************/

"""
    ret += cutil.macro(
        "#define VALIDATE_NET_BYTES(n)\n"
        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
        "\t\t * there's no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
        "\tif (net_offset > net_size)\n"
        '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
    )
    ret += cutil.macro(
        "#define VALIDATE_NET_UTF8(n)\n"
        "\t{\n"
        "\t\tsize_t len = n;\n"
        "\t\tVALIDATE_NET_BYTES(len);\n"
        "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
        '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
        "\t}\n"
    )
    ret += cutil.macro(
        "#define RESERVE_HOST_BYTES(n)\n"
        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
        "\t\t * no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
    )

    ret += "#define GET_U8LE(off) (net_bytes[off])\n"
    ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
    ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
    ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"

    ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    # One entry per currently-open `{` in the function being generated.
    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        """Count the open levels that are also `#if` levels."""
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)

    def indent_lvl() -> int:
        """Return the current brace-nesting depth."""
        return len(indent_stack)

    def tabs(extra: int = 0) -> str:
        """Return the indentation for the current depth plus `extra` levels."""
        return "\t" * (indent_lvl() + extra)

    # Count of fixed-size net bytes seen but not yet emitted as a
    # VALIDATE_NET_BYTES() call; consecutive fixed-size members are
    # coalesced into a single check.
    incr_buf: int

    def incr_flush() -> None:
        """Emit the pending VALIDATE_NET_BYTES() call, if any bytes are pending."""
        nonlocal ret
        nonlocal incr_buf
        if incr_buf:
            ret += f"{tabs()}VALIDATE_NET_BYTES({incr_buf});\n"
            incr_buf = 0

    def close_block() -> None:
        """Close the innermost `{` (and its `#if`, when it has one)."""
        nonlocal ret
        # Pending bytes belong to the block being closed, so they
        # must be emitted before its closing brace.
        incr_flush()
        ret += f"{tabs(-1)}}}\n"
        if indent_stack.pop().ifdef:
            ret += cutil.ifdef_pop(ifdef_lvl())

    def gen_validate_size(path: idlutil.Path) -> None:
        """Emit the size-accounting code for the member at the end of `path`."""
        nonlocal ret
        nonlocal incr_buf

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        if child.in_versions < parent.in_versions:
            # The member only exists in some of the parent's versions.
            # NOTE(review): presumably ifdef_push() returns "" when the
            # requested #if is already open, so that consecutive
            # members with the same versions share one block - confirm
            # against cutil.
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                incr_flush()
                ret += line
                ret += f"{tabs()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                indent_stack.append(IndentLevel(ifdef=True))
        if should_save_offset(parent, child):
            # `incr_buf` bytes have not been added to net_offset yet.
            ret += f"{tabs()}uint32_t {_offset_var(path.elems)} = net_offset + {incr_buf};\n"
        if child.cnt:
            if isinstance(child.cnt, int):
                cnt_str = str(child.cnt)
                cnt_typ = "size_t"
            else:
                assert child.cnt.typ.static_size
                # The LAST_U*LE() macros read relative to net_offset,
                # so the count member must already be accounted for.
                incr_flush()
                cnt_str = f"LAST_U{child.cnt.typ.static_size*8}LE()"
                cnt_typ = c9util.typename(child.cnt.typ)
            if child.membname == "utf8":  # SPECIAL (string)
                assert child.typ.static_size == 1
                # Yes, this is content-validation and "belongs" in
                # gen_validate_content(), not here.  But it's just
                # easier this way.
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_UTF8({cnt_str});\n"
                return
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                if isinstance(child.cnt, int):
                    incr_buf += child.cnt
                    return
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_BYTES({cnt_str});\n"
                return
            # Nested lists get successive loop variables: i, j, k, ...
            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            loopvar = chr(ord("i") + loopdepth - 1)
            incr_flush()
            ret += f"{tabs()}for ({cnt_typ} {loopvar} = 0, cnt = {cnt_str}; {loopvar} < cnt; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
            ret += f"{tabs()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
        if not isinstance(child.typ, idl.Struct):
            incr_buf += child.typ.static_size

    def gen_validate_content(path: idlutil.Path) -> None:
        """Emit the value checks (`val`, `max`, bitfield mask) for a member."""
        nonlocal ret

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        def lookup_sym(sym: str) -> str:
            """Map an expression symbol to the C offset variable of a sibling."""
            if sym.startswith("&"):
                sym = sym[1:]
            return f"{_offset_var(path.elems[:-1])}_{sym}"

        def cmp_operands(expr: typing.Any) -> tuple[int, str, str]:
            """Return (bit width, actual-value C expr, expected-value C expr)."""
            assert child.typ.static_size
            nbits = child.typ.static_size * 8
            # Offsets are kept in uint32_t variables, so an expression
            # involving `end` or an `&`-symbol is compared at (at
            # least) 32 bits.
            if nbits < 32 and any(
                isinstance(tok, idl.ExprSym)
                and (tok.symname == "end" or tok.symname.startswith("&"))
                for tok in expr.tokens
            ):
                nbits = 32
            act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(expr, lookup_sym)})"
            return nbits, act, exp

        if child.val:
            incr_flush()
            nbits, act, exp = cmp_operands(child.val)
            ret += f"{tabs()}if ({act} != {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if child.max:
            incr_flush()
            nbits, act, exp = cmp_operands(child.max)
            ret += f"{tabs()}if ({act} > {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if isinstance(child.typ, idl.Bitfield):
            incr_flush()
            nbytes = child.typ.static_size
            nbits = nbytes * 8
            act = f"GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
            ret += f"{tabs()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
            ret += f"{tabs(2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        """Callback for idlutil.walk().

        Emits the size code for the node at `path` and returns the
        command to keep walking together with a callback, `pop`, that
        finishes the node.
        NOTE(review): presumably walk() invokes `pop` once the node's
        children have been visited - confirm against idlutil.
        """
        # Depth before this node opened anything.
        indent_stack_len = len(indent_stack)
        pop_struct = path.elems[-1].typ if path.elems else path.root
        pop_path = path
        # Depth after this node's own `if`/`for` blocks were opened;
        # assigned below, once gen_validate_size() has run.
        pop_indent_stack_len: int

        def pop() -> None:
            nonlocal ret
            if isinstance(pop_struct, idl.Struct):
                # Close whatever the struct's members left open ...
                while len(indent_stack) > pop_indent_stack_len:
                    close_block()
                # ... then, with all member offsets known, check the
                # members' contents.
                if should_save_end_offset(pop_struct):
                    ret += f"{tabs()}uint32_t {_offset_var(pop_path.elems)}_end = net_offset + {incr_buf};\n"
                for memb in pop_struct.members:
                    gen_validate_content(pop_path.add(memb))
            while len(indent_stack) > indent_stack_len:
                # Leave a single version-`if` block open so that the
                # next sibling can continue inside it.
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                close_block()

        if path.elems:
            gen_validate_size(path)

        pop_indent_stack_len = len(indent_stack)

        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
        else:
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"

        ret += "\tuint32_t net_offset = 0;\n"
        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"

        incr_buf = 0
        # The function body itself is the first level; it is flagged
        # ifdef=True to match the ifdef_push(1, ...) above.
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            close_block()

        incr_flush()
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += "\tif (ret_net_size)\n"
            ret += "\t\t*ret_net_size = net_offset;\n"
        ret += "\treturn (ssize_t)host_size;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret