From bb14e1f16b5d42934e55c7a53117cf50e1153188 Mon Sep 17 00:00:00 2001
From: "Luke T. Shumaker" <lukeshu@lukeshu.com>
Date: Tue, 25 Mar 2025 00:15:43 -0600
Subject: lib9p: protogen: flatten the validate functions, same as (un)marshal

---
 lib9p/protogen/c_validate.py | 396 ++++++++++++++++++++++++++++---------------
 1 file changed, 257 insertions(+), 139 deletions(-)

(limited to 'lib9p/protogen/c_validate.py')

diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
index a3f4348..1630af2 100644
--- a/lib9p/protogen/c_validate.py
+++ b/lib9p/protogen/c_validate.py
@@ -3,6 +3,7 @@
 # Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
 # SPDX-License-Identifier: AGPL-3.0-or-later
 
+import typing
 
 import idl
 
@@ -17,155 +18,272 @@ from . import c9util, cutil, idlutil
 __all__ = ["gen_c_validate"]
 
 
-def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
-    return bool(member.max or member.val or any(m.cnt == member for m in typ.members))
+def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
+    if child.val or child.max or isinstance(child.typ, idl.Bitfield):
+        return True
+    for sibling in parent.members:
+        if sibling.val:
+            for tok in sibling.val.tokens:
+                if isinstance(tok, idl.ExprSym) and tok.symname == f"&{child.membname}":
+                    return True
+        if sibling.max:
+            for tok in sibling.max.tokens:
+                if isinstance(tok, idl.ExprSym) and tok.symname == f"&{child.membname}":
+                    return True
+    return False
+
+
+def should_save_end_offset(struct: idl.Struct) -> bool:
+    for memb in struct.members:
+        if memb.val:
+            for tok in memb.val.tokens:
+                if isinstance(tok, idl.ExprSym) and tok.symname == "end":
+                    return True
+        if memb.max:
+            for tok in memb.max.tokens:
+                if isinstance(tok, idl.ExprSym) and tok.symname == "end":
+                    return True
+    return False
 
 
 def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
     ret = """
 /* validate_* *****************************************************************/
 
-LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
-\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
-\t\t/* If needed-net-size overflowed uint32_t, then
-\t\t * there's no way that actual-net-size will live up to
-\t\t * that.  */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\tif (ctx->net_offset > ctx->net_size)
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
-\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
-\t\t/* If needed-host-size overflowed size_t, then there's
-\t\t * no way that actual-net-size will live up to
-\t\t * that.  */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
-                                         size_t cnt,
-                                         _validate_fn_t item_fn, size_t item_host_size) {
-\tfor (size_t i = 0; i < cnt; i++)
-\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
-\t\t\treturn true;
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
-LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
-LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
-LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
 """
+    ret += cutil.macro(
+        "#define VALIDATE_NET_BYTES(n)\n"
+        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
+        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
+        "\t\t * there's no way that actual-net-size will live up to\n"
+        "\t\t * that.  */\n"
+        '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");\n'
+        "\tif (net_offset > ctx->net_size)\n"
+        '\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, ctx->net_size, __LINE__);\n'
+    )
+    ret += cutil.macro(
+        "#define VALIDATE_NET_UTF8(n)\n"
+        "\t{\n"
+        "\t\tsize_t len = n;\n"
+        "\t\tVALIDATE_NET_BYTES(len);\n"
+        "\t\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[net_offset-len], len))\n"
+        '\t\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
+        "\t}\n"
+    )
+    ret += cutil.macro(
+        "#define RESERVE_HOST_BYTES(n)\n"
+        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
+        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
+        "\t\t * no way that actual-net-size will live up to\n"
+        "\t\t * that.  */\n"
+        '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");\n'
+    )
+
+    ret += "#define GET_U8LE(off)  (ctx->net_bytes[off])\n"
+    ret += "#define GET_U16LE(off) uint16le_decode(&ctx->net_bytes[off])\n"
+    ret += "#define GET_U32LE(off) uint32le_decode(&ctx->net_bytes[off])\n"
+    ret += "#define GET_U64LE(off) uint64le_decode(&ctx->net_bytes[off])\n"
+
+    ret += "#define LAST_U8LE()  GET_U8LE(net_offset-1)\n"
+    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
+    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
+    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"
+
+    class IndentLevel(typing.NamedTuple):
+        ifdef: bool  # whether this is both `{` and `#if`, or just `{`
+
+    indent_stack: list[IndentLevel]
+
+    def ifdef_lvl() -> int:
+        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
+
+    def indent_lvl() -> int:
+        return len(indent_stack)
+
+    incr_buf: int
+
+    def incr_flush() -> None:
+        nonlocal ret
+        nonlocal incr_buf
+        if incr_buf:
+            ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES({incr_buf});\n"
+            incr_buf = 0
+
+    def gen_validate_size(path: idlutil.Path) -> None:
+        nonlocal ret
+        nonlocal incr_buf
+        nonlocal indent_stack
+
+        assert path.elems
+        child = path.elems[-1]
+        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+        assert isinstance(parent, idl.Struct)
+
+        if child.in_versions < parent.in_versions:
+            if line := cutil.ifdef_push(
+                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
+            ):
+                incr_flush()
+                ret += line
+                ret += (
+                    f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+                )
+                indent_stack.append(IndentLevel(ifdef=True))
+        if should_save_offset(parent, child):
+            ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)} = net_offset + {incr_buf};\n"
+        if child.cnt:
+            assert child.cnt.typ.static_size
+            cnt_path = path.parent().add(child.cnt)
+            incr_flush()
+            if child.membname == "utf8":  # SPECIAL (string)
+                # Yes, this is content-validation and "belongs" in
+                # gen_validate_content(), not here.  But it's just
+                # easier this way.
+                ret += f"{'\t'*indent_lvl()}VALIDATE_NET_UTF8(LAST_U{child.cnt.typ.static_size*8}LE());\n"
+                return
+            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
+                ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES(LAST_U{child.cnt.typ.static_size*8}LE());\n"
+                return
+            loopdepth = sum(1 for elem in path.elems if elem.cnt)
+            loopvar = chr(ord("i") + loopdepth - 1)
+            ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0, cnt = LAST_U{child.cnt.typ.static_size*8}LE(); {loopvar} < cnt; {loopvar}++) {{\n"
+            indent_stack.append(IndentLevel(ifdef=False))
+            ret += f"{'\t'*indent_lvl()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
+        if not isinstance(child.typ, idl.Struct):
+            incr_buf += child.typ.static_size
 
-    for typ in idlutil.topo_sorted(typs):
-        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
-        argfn = (
-            c9util.arg_unused
-            if (isinstance(typ, idl.Struct) and not typ.members)
-            else c9util.arg_used
-        )
+    def gen_validate_content(path: idlutil.Path) -> None:
+        nonlocal ret
+        nonlocal incr_buf
+        nonlocal indent_stack
+
+        assert path.elems
+        child = path.elems[-1]
+        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+        assert isinstance(parent, idl.Struct)
+
+        def lookup_sym(sym: str) -> str:
+            if sym.startswith("&"):
+                sym = sym[1:]
+            return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
+
+        if child.val:
+            incr_flush()
+            assert child.typ.static_size
+            nbits = child.typ.static_size * 8
+            nbits = child.typ.static_size * 8
+            if nbits < 32 and any(
+                isinstance(tok, idl.ExprSym)
+                and (tok.symname == "end" or tok.symname.startswith("&"))
+                for tok in child.val.tokens
+            ):
+                nbits = 32
+            act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
+            ret += f"{'\t'*indent_lvl()}if ({act} != {exp})\n"
+            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
+            ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
+        if child.max:
+            incr_flush()
+            assert child.typ.static_size
+            nbits = child.typ.static_size * 8
+            if nbits < 32 and any(
+                isinstance(tok, idl.ExprSym)
+                and (tok.symname == "end" or tok.symname.startswith("&"))
+                for tok in child.max.tokens
+            ):
+                nbits = 32
+            act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
+            ret += f"{'\t'*indent_lvl()}if ({act} > {exp})\n"
+            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
+            ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
+        if isinstance(child.typ, idl.Bitfield):
+            incr_flush()
+            nbytes = child.typ.static_size
+            nbits = nbytes * 8
+            act = f"GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
+            ret += f"{'\t'*indent_lvl()}if ({act} & ~{child.typ.typname}_masks[ctx->ctx->version])\n"
+            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
+            ret += f"{'\t'*(indent_lvl()+2)}{act} & ~{child.typ.typname}_masks[ctx->ctx->version]);\n"
+
+    def handle(
+        path: idlutil.Path,
+    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+        nonlocal ret
+        nonlocal incr_buf
+        nonlocal indent_stack
+        indent_stack_len = len(indent_stack)
+        pop_struct = path.elems[-1].typ if path.elems else path.root
+        pop_path = path
+        pop_indent_stack_len: int
+
+        def pop() -> None:
+            nonlocal ret
+            nonlocal indent_stack
+            nonlocal indent_stack_len
+            nonlocal pop_struct
+            nonlocal pop_path
+            nonlocal pop_indent_stack_len
+            if isinstance(pop_struct, idl.Struct):
+                while len(indent_stack) > pop_indent_stack_len:
+                    incr_flush()
+                    ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+                    if indent_stack.pop().ifdef:
+                        ret += cutil.ifdef_pop(ifdef_lvl())
+                parent = pop_struct
+                path = pop_path
+                if should_save_end_offset(parent):
+                    ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)}_end = net_offset + {incr_buf};\n"
+                for child in parent.members:
+                    gen_validate_content(pop_path.add(child))
+            while len(indent_stack) > indent_stack_len:
+                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
+                    break
+                incr_flush()
+                ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+                if indent_stack.pop().ifdef:
+                    ret += cutil.ifdef_pop(ifdef_lvl())
+
+        if path.elems:
+            gen_validate_size(path)
+
+        pop_indent_stack_len = len(indent_stack)
+
+        return idlutil.WalkCmd.KEEP_GOING, pop
+
+    for typ in typs:
+        if not (
+            isinstance(typ, idl.Message) or typ.typname == "stat"
+        ):  # SPECIAL (include stat)
+            continue
+        assert isinstance(typ, idl.Struct)
         ret += "\n"
         ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
-        ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
-
-        match typ:
-            case idl.Number():
-                ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
-            case idl.Bitfield():
-                ret += f"\t if (validate_{typ.static_size}(ctx))\n"
-                ret += "\t\treturn true;\n"
-                ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
-                if typ.static_size == 1:
-                    ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
-                else:
-                    ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
-                ret += "\tif (val & ~mask)\n"
-                ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
-                ret += "\treturn false;\n"
-            case idl.Struct():  # and idl.Message()
-                if len(typ.members) == 0:
-                    ret += "\treturn false;\n"
-                    ret += "}\n"
-                    continue
-
-                # Pass 1 - declare value variables
-                for member in typ.members:
-                    if should_save_value(typ, member):
-                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
-                        ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
-                ret += cutil.ifdef_pop(1)
-
-                # Pass 2 - declare offset variables
-                mark_offset: set[str] = set()
-                for member in typ.members:
-                    for tok in [*member.max.tokens, *member.val.tokens]:
-                        if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
-                            if tok.symname[1:] not in mark_offset:
-                                ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
-                            mark_offset.add(tok.symname[1:])
-
-                # Pass 3 - main pass
-                ret += "\treturn false\n"
-                for member in typ.members:
-                    ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
-                    ret += "\t    || "
-                    if member.in_versions != typ.in_versions:
-                        ret += "( " + c9util.ver_cond(member.in_versions) + " && "
-                    if member.cnt is not None:
-                        if member.typ.static_size == 1:  # SPECIAL (zerocopy)
-                            ret += f"_validate_size_net(ctx, {member.cnt.membname})"
-                        else:
-                            ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
-                        if typ.typname == "s":  # SPECIAL (string)
-                            ret += '\n\t    || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
-                    else:
-                        if should_save_value(typ, member):
-                            ret += "("
-                        if member.membname in mark_offset:
-                            ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
-                        ret += f"validate_{member.typ.typname}(ctx)"
-                        if member.membname in mark_offset:
-                            ret += "; })"
-                        if should_save_value(typ, member):
-                            nbytes = member.static_size
-                            assert nbytes
-                            if nbytes == 1:
-                                ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
-                            else:
-                                ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
-                    if member.in_versions != typ.in_versions:
-                        ret += " )"
-                    ret += "\n"
-
-                # Pass 4 - validate ,max= and ,val= constraints
-                for member in typ.members:
-
-                    def lookup_sym(sym: str) -> str:
-                        match sym:
-                            case "end":
-                                return "ctx->net_offset"
-                            case _:
-                                assert sym.startswith("&")
-                                return f"_{sym[1:]}_offset"
-
-                    if member.max:
-                        assert member.static_size
-                        nbits = member.static_size * 8
-                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
-                        ret += f"\t    || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
-                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
-                    if member.val:
-                        assert member.static_size
-                        nbits = member.static_size * 8
-                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
-                        ret += f"\t    || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
-                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
-
-                ret += cutil.ifdef_pop(1)
-                ret += "\t    ;\n"
+        if typ.typname == "stat":  # SPECIAL (stat)
+            ret += f"static ssize_t validate_{typ.typname}(struct _validate_ctx *ctx, uint32_t *ret_net_size) {{\n"
+        else:
+            ret += (
+                f"static ssize_t validate_{typ.typname}(struct _validate_ctx *ctx) {{\n"
+            )
+
+        ret += "\tuint32_t net_offset = 0;\n"
+        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"
+
+        incr_buf = 0
+        indent_stack = [IndentLevel(ifdef=True)]
+        idlutil.walk(typ, handle)
+        while len(indent_stack) > 1:
+            incr_flush()
+            ret += f"{'\t'*(indent_lvl()-1)}}}\n"
+            if indent_stack.pop().ifdef:
+                ret += cutil.ifdef_pop(ifdef_lvl())
+
+        incr_flush()
+        if typ.typname == "stat":  # SPECIAL (stat)
+            ret += "\tif (ret_net_size)\n"
+            ret += "\t\t*ret_net_size = net_offset;\n"
+        ret += "\treturn (ssize_t)host_size;\n"
         ret += "}\n"
     ret += cutil.ifdef_pop(0)
     return ret
-- 
cgit v1.2.3-2-g168b