# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025 Luke T. Shumaker
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]


def _member_refs_sym(memb: idl.StructMember, symname: str) -> bool:
    """Return whether `memb`'s `val` or `max` expression contains an
    idl.ExprSym token whose name is exactly `symname`."""
    for expr in (memb.val, memb.max):
        if expr and any(
            isinstance(tok, idl.ExprSym) and tok.symname == symname
            for tok in expr.tokens
        ):
            return True
    return False


def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
    """Return whether the generated C must record the wire offset of
    `child` in an `offsetof_*` local.

    That is the case when `child` itself has a `val`/`max` constraint or
    is a bitfield (the content checks read the value back by offset), or
    when any member of `parent` has a `val`/`max` expression that refers
    to `&child`.
    """
    if child.val or child.max or isinstance(child.typ, idl.Bitfield):
        return True
    ref = f"&{child.membname}"
    return any(_member_refs_sym(sibling, ref) for sibling in parent.members)


def should_save_end_offset(struct: idl.Struct) -> bool:
    """Return whether any member of `struct` has a `val`/`max` expression
    that refers to `end`, in which case the generated C must record the
    struct's end offset in an `offsetof_*_end` local."""
    return any(_member_refs_sym(memb, "end") for memb in struct.members)


def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    """Generate the C `validate_*` functions.

    For each message type in `typs` (plus `stat`, a SPECIAL case), emit
    `static ssize_t validate_<typname>(...)`.  The generated function walks
    `net_bytes`, returns a lib9p_error()/lib9p_errorf() result if the bytes
    are too short, contain invalid UTF-8, or violate a member's
    `val`/`max`/bitfield constraint, and otherwise returns the host size
    the message needs (`sizeof` the struct plus the RESERVE_HOST_BYTES()
    amounts for counted arrays).

    `versions` is not referenced by this function.

    Returns the generated C source text.
    """
    ret = """
/* validate_* *****************************************************************/

"""
    ret += cutil.macro(
        "#define VALIDATE_NET_BYTES(n)\n"
        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
        "\t\t * there's no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
        "\tif (net_offset > net_size)\n"
        '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
    )
    ret += cutil.macro(
        "#define VALIDATE_NET_UTF8(n)\n"
        "\t{\n"
        "\t\tsize_t len = n;\n"
        "\t\tVALIDATE_NET_BYTES(len);\n"
        "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
        '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
        "\t}\n"
    )
    ret += cutil.macro(
        "#define RESERVE_HOST_BYTES(n)\n"
        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
        "\t\t * no way that actual-net-size will live up to\n"
        "\t\t * that. */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
    )
    ret += "#define GET_U8LE(off) (net_bytes[off])\n"
    ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
    ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
    ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"
    ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    # Stack of the C blocks currently open in the function being generated.
    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        """Number of open levels that also opened an `#if`."""
        return sum(1 for lvl in indent_stack if lvl.ifdef)

    def indent_lvl() -> int:
        """Current C nesting depth, in tabs."""
        return len(indent_stack)

    def indent(extra: int = 0) -> str:
        """Tab prefix for the current nesting depth, adjusted by `extra`."""
        return "\t" * (indent_lvl() + extra)

    # Statically-sized bytes that have been walked past but not yet
    # emitted as a VALIDATE_NET_BYTES(); batching them lets a run of
    # fixed-size members be checked with a single macro invocation.
    incr_buf: int

    def incr_flush() -> None:
        """Emit VALIDATE_NET_BYTES() for any bytes pending in incr_buf."""
        nonlocal ret
        nonlocal incr_buf
        if incr_buf:
            ret += f"{indent()}VALIDATE_NET_BYTES({incr_buf});\n"
            incr_buf = 0

    def close_level() -> None:
        """Flush pending bytes, then close the innermost open `{` (and its
        `#if`, if that level opened one)."""
        nonlocal ret
        incr_flush()
        ret += f"{indent(-1)}}}\n"
        if indent_stack.pop().ifdef:
            ret += cutil.ifdef_pop(ifdef_lvl())

    def offsetof_name(elems: typing.Sequence[idl.StructMember]) -> str:
        """C identifier of the saved-offset local for the member path `elems`."""
        return "offsetof" + "".join("_" + m.membname for m in elems)

    def gen_validate_size(path: idlutil.Path) -> None:
        """Emit the size checks for the member at the end of `path`.

        Opens a version-conditional block if the member is not present in
        every version its parent is, saves the member's offset if a check
        needs it, and then either accounts for a fixed-size member in
        incr_buf or emits the checks/loop for a counted member.  Blocks
        opened here are left on indent_stack for handle()'s pop() to close.
        """
        nonlocal ret
        nonlocal incr_buf
        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        if child.in_versions < parent.in_versions:
            # Only present in some versions: guard it both at preprocessor
            # time and at run time.
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                incr_flush()
                ret += line
            ret += f"{indent()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
            indent_stack.append(IndentLevel(ifdef=True))
        if should_save_offset(parent, child):
            # The incr_buf bytes precede child but are not yet counted in
            # net_offset.
            ret += f"{indent()}uint32_t {offsetof_name(path.elems)} = net_offset + {incr_buf};\n"
        if child.cnt:
            assert child.cnt.typ.static_size
            cnt_nbits = child.cnt.typ.static_size * 8
            # NOTE(review): cnt_path is never read; kept only in case
            # Path.parent()/Path.add() validate something -- TODO confirm
            # and remove.
            cnt_path = path.parent().add(child.cnt)
            # Bring net_offset up to date: LAST_U*LE() reads the count from
            # the bytes just before net_offset, i.e. this relies on the
            # count member immediately preceding child.
            incr_flush()
            if child.membname == "utf8":  # SPECIAL (string)
                # Yes, this is content-validation and "belongs" in
                # gen_validate_content(), not here. But it's just
                # easier this way.
                ret += f"{indent()}VALIDATE_NET_UTF8(LAST_U{cnt_nbits}LE());\n"
                return
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                ret += f"{indent()}VALIDATE_NET_BYTES(LAST_U{cnt_nbits}LE());\n"
                return
            # Nested counted arrays get loop variables i, j, k, ...
            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            loopvar = chr(ord("i") + loopdepth - 1)
            ret += f"{indent()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0, cnt = LAST_U{cnt_nbits}LE(); {loopvar} < cnt; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
            ret += f"{indent()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
        if not isinstance(child.typ, idl.Struct):
            # Presumably struct members are accounted for one by one as
            # idlutil.walk() descends into them.
            incr_buf += child.typ.static_size

    def gen_validate_content(path: idlutil.Path) -> None:
        """Emit the value checks (`val`, `max`, bitfield mask) for the member
        at the end of `path`.

        The member is read back from net_bytes through its saved
        `offsetof_*` local (see should_save_offset()).
        """
        nonlocal ret
        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        def lookup_sym(sym: str) -> str:
            """Map an IDL expression symbol (e.g. `end` or `&member`) to the
            C local that holds that offset."""
            if sym.startswith("&"):
                sym = sym[1:]
            return f"{offsetof_name(path.elems[:-1])}_{sym}"

        def cmp_nbits(field_nbits: int, expr: typing.Any) -> int:
            """Width to compare in: the field's own width, widened to 32 bits
            when `expr` uses an offset (`end` or `&member`), because the
            saved offsets are uint32_t."""
            if field_nbits < 32 and any(
                isinstance(tok, idl.ExprSym)
                and (tok.symname == "end" or tok.symname.startswith("&"))
                for tok in expr.tokens
            ):
                return 32
            return field_nbits

        member_off = lookup_sym(f"&{child.membname}")

        if child.val:
            incr_flush()
            assert child.typ.static_size
            field_nbits = child.typ.static_size * 8
            nbits = cmp_nbits(field_nbits, child.val)
            # Read the field at its own width (a wider GET_U*LE() would also
            # decode the bytes that follow it), then widen via the cast.
            act = f"(uint{nbits}_t)GET_U{field_nbits}LE({member_off})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
            ret += f"{indent()}if ({act} != {exp})\n"
            ret += f'{indent(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
            ret += f"{indent(2)}{act}, {exp});\n"
        if child.max:
            incr_flush()
            assert child.typ.static_size
            field_nbits = child.typ.static_size * 8
            nbits = cmp_nbits(field_nbits, child.max)
            # Same read-at-own-width rule as for `val` above.
            act = f"(uint{nbits}_t)GET_U{field_nbits}LE({member_off})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
            ret += f"{indent()}if ({act} > {exp})\n"
            ret += f'{indent(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
            ret += f"{indent(2)}{act}, {exp});\n"
        if isinstance(child.typ, idl.Bitfield):
            incr_flush()
            nbytes = child.typ.static_size
            nbits = nbytes * 8
            act = f"GET_U{nbits}LE({member_off})"
            ret += f"{indent()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
            ret += f'{indent(1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
            ret += f"{indent(2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        """idlutil.walk() callback for one node of the type being validated.

        On entry, emits the size checks for the member at the end of `path`
        (nothing for the root).  Returns a `pop` callback, presumably run by
        idlutil.walk() after the node's children have been visited, that:
        closes the blocks the children opened; if the node is a struct,
        emits its `_end` offset (when needed) and its members' content
        checks; and finally closes the blocks opened for the node itself,
        except for a trailing version-conditional level.  That level is left
        open and is closed later by an enclosing pop() or by the
        end-of-function loop.
        """
        indent_stack_len = len(indent_stack)
        pop_struct = path.elems[-1].typ if path.elems else path.root
        pop_path = path
        pop_indent_stack_len: int

        def pop() -> None:
            nonlocal ret
            if isinstance(pop_struct, idl.Struct):
                while len(indent_stack) > pop_indent_stack_len:
                    close_level()
                if should_save_end_offset(pop_struct):
                    ret += f"{indent()}uint32_t {offsetof_name(pop_path.elems)}_end = net_offset + {incr_buf};\n"
                # The struct's offsets (including `_end`) have now been
                # emitted, so its members' value checks can refer to them.
                for child in pop_struct.members:
                    gen_validate_content(pop_path.add(child))
            while len(indent_stack) > indent_stack_len:
                if (
                    len(indent_stack) == indent_stack_len + 1
                    and indent_stack[-1].ifdef
                ):
                    break
                close_level()

        if path.elems:
            gen_validate_size(path)

        pop_indent_stack_len = len(indent_stack)

        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
        else:
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"

        ret += "\tuint32_t net_offset = 0;\n"
        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"

        incr_buf = 0
        # The function body is the base level; it is marked ifdef=True so
        # that ifdef_lvl() accounts for the ifdef_push(1, ...) above.
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            close_level()

        incr_flush()
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += "\tif (ret_net_size)\n"
            ret += "\t\t*ret_net_size = net_offset;\n"
        ret += "\treturn (ssize_t)host_size;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret