summaryrefslogtreecommitdiff
path: root/lib9p/protogen/c_validate.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r--lib9p/protogen/c_validate.py299
1 files changed, 0 insertions, 299 deletions
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
deleted file mode 100644
index 535a750..0000000
--- a/lib9p/protogen/c_validate.py
+++ /dev/null
@@ -1,299 +0,0 @@
-# lib9p/protogen/c_validate.py - Generate C validation functions
-#
-# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
-# SPDX-License-Identifier: AGPL-3.0-or-later
-
-import typing
-
-import idl
-
-from . import c9util, cutil, idlutil
-
-# This strives to be "general-purpose" in that it just acts on the
-# *.9p inputs; but (unfortunately?) there are a few special-cases in
-# this script, marked with "SPECIAL".
-
-
-# pylint: disable=unused-variable
-__all__ = ["gen_c_validate"]
-
-
-def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
- if child.val or child.max or isinstance(child.typ, idl.Bitfield):
- return True
- for sibling in parent.members:
- if sibling.val:
- for tok in sibling.val.tokens:
- if isinstance(tok, idl.ExprOff) and tok.membname == child.membname:
- return True
- if sibling.max:
- for tok in sibling.max.tokens:
- if isinstance(tok, idl.ExprOff) and tok.membname == child.membname:
- return True
- return False
-
-
-def should_save_end_offset(struct: idl.Struct) -> bool:
- for memb in struct.members:
- if memb.val:
- for tok in memb.val.tokens:
- if isinstance(tok, idl.ExprSym) and tok.symname == "end":
- return True
- if memb.max:
- for tok in memb.max.tokens:
- if isinstance(tok, idl.ExprSym) and tok.symname == "end":
- return True
- return False
-
-
-def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
- ret = """
-/* validate_* *****************************************************************/
-
-"""
- ret += cutil.macro(
- "#define VALIDATE_NET_BYTES(n)\n"
- "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
- "\t\t/* If needed-net-size overflowed uint32_t, then\n"
- "\t\t * there's no way that actual-net-size will live up to\n"
- "\t\t * that. */\n"
- '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
- "\tif (net_offset > net_size)\n"
- '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
- )
- ret += cutil.macro(
- "#define VALIDATE_NET_UTF8(n)\n"
- "\t{\n"
- "\t\tsize_t len = n;\n"
- "\t\tVALIDATE_NET_BYTES(len);\n"
- "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
- '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
- "\t}\n"
- )
- ret += cutil.macro(
- "#define RESERVE_HOST_BYTES(n)\n"
- "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
- "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
- "\t\t * no way that actual-net-size will live up to\n"
- "\t\t * that. */\n"
- '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
- )
-
- ret += "#define GET_U8LE(off) (net_bytes[off])\n"
- ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
- ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
- ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"
-
- ret += "#define LAST_U8LE() GET_U8LE(net_offset-1)\n"
- ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
- ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
- ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"
-
- class IndentLevel(typing.NamedTuple):
- ifdef: bool # whether this is both `{` and `#if`, or just `{`
-
- indent_stack: list[IndentLevel]
-
- def ifdef_lvl() -> int:
- return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
-
- def indent_lvl() -> int:
- return len(indent_stack)
-
- incr_buf: int
-
- def incr_flush() -> None:
- nonlocal ret
- nonlocal incr_buf
- if incr_buf:
- ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES({incr_buf});\n"
- incr_buf = 0
-
- def gen_validate_size(path: idlutil.Path) -> None:
- nonlocal ret
- nonlocal incr_buf
- nonlocal indent_stack
-
- assert path.elems
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- assert isinstance(parent, idl.Struct)
-
- if child.in_versions < parent.in_versions:
- if line := cutil.ifdef_push(
- ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
- ):
- incr_flush()
- ret += line
- ret += (
- f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
- )
- indent_stack.append(IndentLevel(ifdef=True))
- if should_save_offset(parent, child):
- ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)} = net_offset + {incr_buf};\n"
- if child.cnt:
- if isinstance(child.cnt, int):
- cnt_str = str(child.cnt)
- cnt_typ = "size_t"
- else:
- assert child.cnt.typ.static_size
- incr_flush()
- cnt_str = f"LAST_U{child.cnt.typ.static_size*8}LE()"
- cnt_typ = c9util.typename(child.cnt.typ)
- if child.membname == "utf8": # SPECIAL (string)
- assert child.typ.static_size == 1
- # Yes, this is content-validation and "belongs" in
- # gen_validate_content(), not here. But it's just
- # easier this way.
- incr_flush()
- ret += f"{'\t'*indent_lvl()}VALIDATE_NET_UTF8({cnt_str});\n"
- return
- if child.typ.static_size == 1: # SPECIAL (zerocopy)
- if isinstance(child.cnt, int):
- incr_buf += child.cnt
- return
- incr_flush()
- ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES({cnt_str});\n"
- return
- loopdepth = sum(1 for elem in path.elems if elem.cnt)
- loopvar = chr(ord("i") + loopdepth - 1)
- incr_flush()
- ret += f"{'\t'*indent_lvl()}for ({cnt_typ} {loopvar} = 0, cnt = {cnt_str}; {loopvar} < cnt; {loopvar}++) {{\n"
- indent_stack.append(IndentLevel(ifdef=False))
- ret += f"{'\t'*indent_lvl()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
- if not isinstance(child.typ, idl.Struct):
- incr_buf += child.typ.static_size
-
- def gen_validate_content(path: idlutil.Path) -> None:
- nonlocal ret
- nonlocal incr_buf
- nonlocal indent_stack
-
- assert path.elems
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- assert isinstance(parent, idl.Struct)
-
- def lookup_sym(sym: str) -> str:
- if sym.startswith("&"):
- sym = sym[1:]
- return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
-
- if child.val:
- incr_flush()
- assert child.typ.static_size
- nbits = child.typ.static_size * 8
- nbits = child.typ.static_size * 8
- if nbits < 32 and any(
- isinstance(tok, idl.ExprSym)
- and (tok.symname == "end" or tok.symname.startswith("&"))
- for tok in child.val.tokens
- ):
- nbits = 32
- act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
- exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
- ret += f"{'\t'*indent_lvl()}if ({act} != {exp})\n"
- ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
- ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
- if child.max:
- incr_flush()
- assert child.typ.static_size
- nbits = child.typ.static_size * 8
- if nbits < 32 and any(
- isinstance(tok, idl.ExprSym)
- and (tok.symname == "end" or tok.symname.startswith("&"))
- for tok in child.max.tokens
- ):
- nbits = 32
- act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
- exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
- ret += f"{'\t'*indent_lvl()}if ({act} > {exp})\n"
- ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
- ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
- if isinstance(child.typ, idl.Bitfield):
- incr_flush()
- nbytes = child.typ.static_size
- nbits = nbytes * 8
- act = f"GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
- ret += f"{'\t'*indent_lvl()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
- ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
- ret += f"{'\t'*(indent_lvl()+2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"
-
- def handle(
- path: idlutil.Path,
- ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
- nonlocal ret
- nonlocal incr_buf
- nonlocal indent_stack
- indent_stack_len = len(indent_stack)
- pop_struct = path.elems[-1].typ if path.elems else path.root
- pop_path = path
- pop_indent_stack_len: int
-
- def pop() -> None:
- nonlocal ret
- nonlocal indent_stack
- nonlocal indent_stack_len
- nonlocal pop_struct
- nonlocal pop_path
- nonlocal pop_indent_stack_len
- if isinstance(pop_struct, idl.Struct):
- while len(indent_stack) > pop_indent_stack_len:
- incr_flush()
- ret += f"{'\t'*(indent_lvl()-1)}}}\n"
- if indent_stack.pop().ifdef:
- ret += cutil.ifdef_pop(ifdef_lvl())
- parent = pop_struct
- path = pop_path
- if should_save_end_offset(parent):
- ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)}_end = net_offset + {incr_buf};\n"
- for child in parent.members:
- gen_validate_content(pop_path.add(child))
- while len(indent_stack) > indent_stack_len:
- if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
- break
- incr_flush()
- ret += f"{'\t'*(indent_lvl()-1)}}}\n"
- if indent_stack.pop().ifdef:
- ret += cutil.ifdef_pop(ifdef_lvl())
-
- if path.elems:
- gen_validate_size(path)
-
- pop_indent_stack_len = len(indent_stack)
-
- return idlutil.WalkCmd.KEEP_GOING, pop
-
- for typ in typs:
- if not (
- isinstance(typ, idl.Message) or typ.typname == "stat"
- ): # SPECIAL (include stat)
- continue
- assert isinstance(typ, idl.Struct)
- ret += "\n"
- ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
- if typ.typname == "stat": # SPECIAL (stat)
- ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
- else:
- ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"
-
- ret += "\tuint32_t net_offset = 0;\n"
- ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"
-
- incr_buf = 0
- indent_stack = [IndentLevel(ifdef=True)]
- idlutil.walk(typ, handle)
- while len(indent_stack) > 1:
- incr_flush()
- ret += f"{'\t'*(indent_lvl()-1)}}}\n"
- if indent_stack.pop().ifdef:
- ret += cutil.ifdef_pop(ifdef_lvl())
-
- incr_flush()
- if typ.typname == "stat": # SPECIAL (stat)
- ret += "\tif (ret_net_size)\n"
- ret += "\t\t*ret_net_size = net_offset;\n"
- ret += "\treturn (ssize_t)host_size;\n"
- ret += "}\n"
- ret += cutil.ifdef_pop(0)
- return ret