diff options
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r-- | lib9p/protogen/c_validate.py | 299 |
1 files changed, 0 insertions, 299 deletions
# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]


def _member_exprs(memb: idl.StructMember) -> typing.Iterator[typing.Any]:
    """Yield whichever of `memb.val` and `memb.max` are truthy, in that order."""
    for expr in (memb.val, memb.max):
        if expr:
            yield expr


def _uses_offset_syms(expr: typing.Any) -> bool:
    """Return whether `expr` has an `idl.ExprSym` token named "end" or starting with "&"."""
    return any(
        isinstance(tok, idl.ExprSym)
        and (tok.symname == "end" or tok.symname.startswith("&"))
        for tok in expr.tokens
    )


def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
    """Return whether the generated C must record the net offset of `child`.

    That is the case when `child` itself has a `val` or `max`
    expression or is a bitfield (gen_validate_content() reads the
    member back by offset), or when any member of `parent` has a
    `val`/`max` expression containing an `idl.ExprOff` token that
    names `child`.
    """
    if child.val or child.max or isinstance(child.typ, idl.Bitfield):
        return True
    return any(
        isinstance(tok, idl.ExprOff) and tok.membname == child.membname
        for sibling in parent.members
        for expr in _member_exprs(sibling)
        for tok in expr.tokens
    )


def should_save_end_offset(struct: idl.Struct) -> bool:
    """Return whether any member's `val`/`max` expression uses the symbol "end"."""
    return any(
        isinstance(tok, idl.ExprSym) and tok.symname == "end"
        for memb in struct.members
        for expr in _member_exprs(memb)
        for tok in expr.tokens
    )


def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    """Return the C source text of the validate_*() functions.

    The result consists of a banner comment, a handful of helper
    macros, and then one `static ssize_t validate_<typname>(...)`
    function for each element of `typs` that is an `idl.Message` or
    is named "stat"; other types are skipped.

    `versions` is accepted but not consulted by this function.
    """
    ret = """
/* validate_* *****************************************************************/

"""
    # NOTE(review): cutil.macro() is defined elsewhere; presumably it
    # turns these multi-line bodies into valid multi-line #defines.
    ret += cutil.macro(
        "#define VALIDATE_NET_BYTES(n)\n"
        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
        "\t\t * there's no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
        "\tif (net_offset > net_size)\n"
        '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
    )
    ret += cutil.macro(
        "#define VALIDATE_NET_UTF8(n)\n"
        "\t{\n"
        "\t\tsize_t len = n;\n"
        "\t\tVALIDATE_NET_BYTES(len);\n"
        "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
        '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
        "\t}\n"
    )
    ret += cutil.macro(
        "#define RESERVE_HOST_BYTES(n)\n"
        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
        "\t\t * no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
    )

    # Read a little-endian integer at an absolute offset ...
    ret += "#define GET_U8LE(off) (net_bytes[off])\n"
    ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
    ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
    ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"

    # ... or the one that ends at the current net_offset.
    ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    # One entry per currently-open `{` in the C function being
    # generated; (re)initialized for each type in the main loop below.
    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        """Return how many open indent levels are also `#if` levels."""
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)

    def indent_lvl() -> int:
        """Return the current C indentation depth."""
        return len(indent_stack)

    def tabs(extra: int = 0) -> str:
        """Return the tab prefix for the current depth plus `extra`.

        Kept as a helper so that no f-string below needs a backslash
        inside a replacement field, which only parses on Python 3.12+.
        """
        return "\t" * (indent_lvl() + extra)

    # Number of fixed-size net bytes accumulated but not yet emitted
    # as a VALIDATE_NET_BYTES() call; lets adjacent fixed-size members
    # share a single check.
    incr_buf: int

    def incr_flush() -> None:
        """Emit the pending VALIDATE_NET_BYTES() call, if any bytes are pending."""
        nonlocal ret
        nonlocal incr_buf
        if incr_buf:
            ret += f"{tabs()}VALIDATE_NET_BYTES({incr_buf});\n"
            incr_buf = 0

    def gen_validate_size(path: idlutil.Path) -> None:
        """Emit the size checks for the member at the end of `path`.

        May open new indent levels (a version-conditional `if` and/or
        a `for` loop over a counted member); closing them is left to
        the `pop` callback built in handle().
        """
        nonlocal ret
        nonlocal incr_buf

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        if child.in_versions < parent.in_versions:
            # The member only exists in some of the parent's versions.
            # A block is opened only when ifdef_push() returns
            # non-empty text.
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                incr_flush()
                ret += line
                ret += f"{tabs()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                indent_stack.append(IndentLevel(ifdef=True))
        if should_save_offset(parent, child):
            # net_offset does not yet include the unflushed bytes, so
            # add them in explicitly.
            offset_name = "".join("_" + m.membname for m in path.elems)
            ret += f"{tabs()}uint32_t offsetof{offset_name} = net_offset + {incr_buf};\n"
        if child.cnt:
            if isinstance(child.cnt, int):
                cnt_str = str(child.cnt)
                cnt_typ = "size_t"
            else:
                assert child.cnt.typ.static_size
                # Flush so that LAST_U*LE() reads relative to an
                # up-to-date net_offset.
                incr_flush()
                cnt_str = f"LAST_U{child.cnt.typ.static_size*8}LE()"
                cnt_typ = c9util.typename(child.cnt.typ)
            if child.membname == "utf8":  # SPECIAL (string)
                assert child.typ.static_size == 1
                # Yes, this is content-validation and "belongs" in
                # gen_validate_content(), not here. But it's just
                # easier this way.
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_UTF8({cnt_str});\n"
                return
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                if isinstance(child.cnt, int):
                    incr_buf += child.cnt
                    return
                incr_flush()
                ret += f"{tabs()}VALIDATE_NET_BYTES({cnt_str});\n"
                return
            # The loop variable is named by nesting depth: i, j, k, ...
            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            loopvar = chr(ord("i") + loopdepth - 1)
            incr_flush()
            ret += f"{tabs()}for ({cnt_typ} {loopvar} = 0, cnt = {cnt_str}; {loopvar} < cnt; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
            ret += f"{tabs()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
        if not isinstance(child.typ, idl.Struct):
            incr_buf += child.typ.static_size

    def gen_validate_content(path: idlutil.Path) -> None:
        """Emit the `val`, `max`, and bitfield checks for the member at the end of `path`."""
        nonlocal ret

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        def lookup_sym(sym: str) -> str:
            """Map an expression symbol to the C variable holding that offset."""
            if sym.startswith("&"):
                sym = sym[1:]
            prefix = "".join("_" + m.membname for m in path.elems[:-1])
            return f"offsetof{prefix}_{sym}"

        def compare_operands(expr: typing.Any) -> tuple[int, str, str]:
            """Return `(cmp_bits, act, exp)` for checking `child` against `expr`.

            `act` reads the member at its own width; both operands are
            then cast to `cmp_bits`, which is widened to 32 when a
            narrower member is compared against an expression that
            uses "end" or "&" symbols.
            """
            assert child.typ.static_size
            field_bits = child.typ.static_size * 8
            cmp_bits = field_bits
            if field_bits < 32 and _uses_offset_syms(expr):
                cmp_bits = 32
            field_off = lookup_sym(f"&{child.membname}")
            # The read must use field_bits: reading cmp_bits here
            # would pull in the bytes of whatever follows the member.
            act = f"(uint{cmp_bits}_t)GET_U{field_bits}LE({field_off})"
            exp = f"(uint{cmp_bits}_t)({c9util.idl_expr(expr, lookup_sym)})"
            return cmp_bits, act, exp

        if child.val:
            incr_flush()
            nbits, act, exp = compare_operands(child.val)
            ret += f"{tabs()}if ({act} != {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if child.max:
            incr_flush()
            nbits, act, exp = compare_operands(child.max)
            ret += f"{tabs()}if ({act} > {exp})\n"
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
            ret += f"{tabs(2)}{act}, {exp});\n"
        if isinstance(child.typ, idl.Bitfield):
            incr_flush()
            nbytes = child.typ.static_size
            nbits = nbytes * 8
            field_off = lookup_sym(f"&{child.membname}")
            act = f"GET_U{nbits}LE({field_off})"
            unknown = f"{act} & ~{child.typ.typname}_masks[ctx->version]"
            ret += f"{tabs()}if ({unknown})\n"
            # NOTE(review): with the `#` flag the "0x" prefix counts
            # toward the printf field width, so a width of nbytes*2
            # may not zero-pad as intended -- confirm whether
            # nbytes*2+2 was meant.
            ret += f'{tabs(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
            ret += f"{tabs(2)}{unknown});\n"

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        """Callback passed to idlutil.walk().

        Emits the size checks for `path` right away and returns a
        `pop` callback that emits the content checks and closes the
        indent levels opened for `path`.
        """
        indent_stack_len = len(indent_stack)
        pop_struct = path.elems[-1].typ if path.elems else path.root
        pop_path = path
        pop_indent_stack_len: int

        def pop() -> None:
            nonlocal ret
            if isinstance(pop_struct, idl.Struct):
                # Close anything opened since gen_validate_size() ran
                # for this struct, so that the content checks are
                # emitted at the struct's own depth.
                while len(indent_stack) > pop_indent_stack_len:
                    incr_flush()
                    ret += f"{tabs(-1)}}}\n"
                    if indent_stack.pop().ifdef:
                        ret += cutil.ifdef_pop(ifdef_lvl())
                if should_save_end_offset(pop_struct):
                    offset_name = "".join("_" + m.membname for m in pop_path.elems)
                    ret += f"{tabs()}uint32_t offsetof{offset_name}_end = net_offset + {incr_buf};\n"
                for member in pop_struct.members:
                    gen_validate_content(pop_path.add(member))
            # Close what this path opened, but leave a sole trailing
            # `#if` level open.
            while len(indent_stack) > indent_stack_len:
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                incr_flush()
                ret += f"{tabs(-1)}}}\n"
                if indent_stack.pop().ifdef:
                    ret += cutil.ifdef_pop(ifdef_lvl())

        if path.elems:
            gen_validate_size(path)

        pop_indent_stack_len = len(indent_stack)

        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
        else:
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"

        ret += "\tuint32_t net_offset = 0;\n"
        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"

        # The function body itself counts as the first indent level.
        incr_buf = 0
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            incr_flush()
            ret += f"{tabs(-1)}}}\n"
            if indent_stack.pop().ifdef:
                ret += cutil.ifdef_pop(ifdef_lvl())

        incr_flush()
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += "\tif (ret_net_size)\n"
            ret += "\t\t*ret_net_size = net_offset;\n"
        ret += "\treturn (ssize_t)host_size;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret