diff options
Diffstat (limited to 'lib9p/protogen/c_validate.py')
-rw-r--r-- | lib9p/protogen/c_validate.py | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py new file mode 100644 index 0000000..a3f4348 --- /dev/null +++ b/lib9p/protogen/c_validate.py @@ -0,0 +1,171 @@ +# lib9p/protogen/c_validate.py - Generate C validation functions +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import idl + +from . import c9util, cutil, idlutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["gen_c_validate"] + + +def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: + return bool(member.max or member.val or any(m.cnt == member for m in typ.members)) + + +def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str: + ret = """ +/* validate_* *****************************************************************/ + +LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { +\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) +\t\t/* If needed-net-size overflowed uint32_t, then +\t\t * there's no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\tif (ctx->net_offset > ctx->net_size) +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { +\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) +\t\t/* If needed-host-size overflowed size_t, then there's +\t\t * no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, + size_t cnt, + _validate_fn_t item_fn, size_t item_host_size) { +\tfor (size_t i = 0; i < cnt; i++) +\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) +\t\t\treturn true; +\treturn false; +} + +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } +""" + + for typ in idlutil.topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = ( + c9util.arg_unused + if (isinstance(typ, idl.Struct) and not typ.members) + else c9util.arg_used + ) + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" + + match typ: + case idl.Number(): + ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" + case idl.Bitfield(): + ret += f"\t if (validate_{typ.static_size}(ctx))\n" + ret += "\t\treturn true;\n" + ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" + if typ.static_size == 1: + ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + else: + ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += "\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' + ret += "\treturn false;\n" + case idl.Struct(): # and idl.Message() + if len(typ.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + + # Pass 1 - declare value variables + for member in typ.members: + if should_save_value(typ, member): + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" + ret += cutil.ifdef_pop(1) + + # Pass 2 - declare offset variables + mark_offset: set[str] = set() + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): + if tok.symname[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" + mark_offset.add(tok.symname[1:]) + + # Pass 3 - main pass + ret += "\treturn false\n" + for member in typ.members: + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += "\t || " + if member.in_versions != typ.in_versions: + ret += "( " + c9util.ver_cond(member.in_versions) + " && " + if member.cnt is not None: + if member.typ.static_size == 1: # SPECIAL (zerocopy) + ret += f"_validate_size_net(ctx, {member.cnt.membname})" + else: + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" + if typ.typname == "s": # SPECIAL (string) + ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' + else: + if should_save_value(typ, member): + ret += "(" + if member.membname in mark_offset: + ret += f"({{ _{member.membname}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.typname}(ctx)" + if member.membname in mark_offset: + ret += "; })" + if should_save_value(typ, member): + nbytes = member.static_size + assert nbytes + if nbytes == 1: + ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + else: + ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" + if member.in_versions != typ.in_versions: + ret += " )" + ret += "\n" + + # Pass 4 - validate ,max= and ,val= constraints + for member in typ.members: + + def lookup_sym(sym: str) -> str: + match sym: + case "end": + return "ctx->net_offset" + case _: + assert sym.startswith("&") + return f"_{sym[1:]}_offset" + + if member.max: + assert member.static_size + nbits = member.static_size * 8 + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' + if member.val: + assert member.static_size + nbits = member.static_size * 8 + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' + + ret += cutil.ifdef_pop(1) + ret += "\t ;\n" + ret += "}\n" + ret += cutil.ifdef_pop(0) + return ret |