summaryrefslogtreecommitdiff
path: root/lib9p/protogen/c_validate.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r--lib9p/protogen/c_validate.py171
1 files changed, 171 insertions, 0 deletions
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
new file mode 100644
index 0000000..a3f4348
--- /dev/null
+++ b/lib9p/protogen/c_validate.py
@@ -0,0 +1,171 @@
+# lib9p/protogen/c_validate.py - Generate C validation functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_validate"]
+
+
+def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
+ return bool(member.max or member.val or any(m.cnt == member for m in typ.members))
+
+
+def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
+ ret = """
+/* validate_* *****************************************************************/
+
+LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
+\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
+\t\t/* If needed-net-size overflowed uint32_t, then
+\t\t * there's no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\tif (ctx->net_offset > ctx->net_size)
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
+\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
+\t\t/* If needed-host-size overflowed size_t, then there's
+\t\t * no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
+ size_t cnt,
+ _validate_fn_t item_fn, size_t item_host_size) {
+\tfor (size_t i = 0; i < cnt; i++)
+\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
+\t\t\treturn true;
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
+LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
+LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
+LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
+"""
+
+ for typ in idlutil.topo_sorted(typs):
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = (
+ c9util.arg_unused
+ if (isinstance(typ, idl.Struct) and not typ.members)
+ else c9util.arg_used
+ )
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
+
+ match typ:
+ case idl.Number():
+ ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
+ case idl.Bitfield():
+ ret += f"\t if (validate_{typ.static_size}(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
+ if typ.static_size == 1:
+ ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
+ else:
+ ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += "\tif (val & ~mask)\n"
+ ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
+ ret += "\treturn false;\n"
+ case idl.Struct(): # and idl.Message()
+ if len(typ.members) == 0:
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ # Pass 1 - declare value variables
+ for member in typ.members:
+ if should_save_value(typ, member):
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
+ ret += cutil.ifdef_pop(1)
+
+ # Pass 2 - declare offset variables
+ mark_offset: set[str] = set()
+ for member in typ.members:
+ for tok in [*member.max.tokens, *member.val.tokens]:
+ if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
+ if tok.symname[1:] not in mark_offset:
+ ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
+ mark_offset.add(tok.symname[1:])
+
+ # Pass 3 - main pass
+ ret += "\treturn false\n"
+ for member in typ.members:
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += "\t || "
+ if member.in_versions != typ.in_versions:
+ ret += "( " + c9util.ver_cond(member.in_versions) + " && "
+ if member.cnt is not None:
+ if member.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret += f"_validate_size_net(ctx, {member.cnt.membname})"
+ else:
+ ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
+ if typ.typname == "s": # SPECIAL (string)
+ ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
+ else:
+ if should_save_value(typ, member):
+ ret += "("
+ if member.membname in mark_offset:
+ ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
+ ret += f"validate_{member.typ.typname}(ctx)"
+ if member.membname in mark_offset:
+ ret += "; })"
+ if should_save_value(typ, member):
+ nbytes = member.static_size
+ assert nbytes
+ if nbytes == 1:
+ ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
+ else:
+ ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
+ if member.in_versions != typ.in_versions:
+ ret += " )"
+ ret += "\n"
+
+ # Pass 4 - validate ,max= and ,val= constraints
+ for member in typ.members:
+
+ def lookup_sym(sym: str) -> str:
+ match sym:
+ case "end":
+ return "ctx->net_offset"
+ case _:
+ assert sym.startswith("&")
+ return f"_{sym[1:]}_offset"
+
+ if member.max:
+ assert member.static_size
+ nbits = member.static_size * 8
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
+ if member.val:
+ assert member.static_size
+ nbits = member.static_size * 8
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
+
+ ret += cutil.ifdef_pop(1)
+ ret += "\t ;\n"
+ ret += "}\n"
+ ret += cutil.ifdef_pop(0)
+ return ret