# lib9p/protogen/c_validate.py - Generate C validation functions # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import idl from . import c9util, cutil, idlutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # pylint: disable=unused-variable __all__ = ["gen_c_validate"] def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: return bool(member.max or member.val or any(m.cnt == member for m in typ.members)) def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str: ret = """ /* validate_* *****************************************************************/ LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { \tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) \t\t/* If needed-net-size overflowed uint32_t, then \t\t * there's no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \tif (ctx->net_offset > ctx->net_size) \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { \tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) \t\t/* If needed-host-size overflowed size_t, then there's \t\t * no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { \tfor (size_t i = 0; i < cnt; i++) \t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) \t\t\treturn true; \treturn false; } LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ for typ in idlutil.topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = ( c9util.arg_unused if (isinstance(typ, idl.Struct) and not typ.members) else c9util.arg_used ) ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: case idl.Number(): ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += "\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case idl.Struct(): # and idl.Message() if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue # Pass 1 - declare value variables for member in typ.members: if should_save_value(typ, member): ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" ret += cutil.ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): if tok.symname[1:] not in mark_offset: ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" mark_offset.add(tok.symname[1:]) # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c9util.ver_cond(member.in_versions) + " && " if member.cnt is not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" if typ.typname == "s": # SPECIAL (string) ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: if should_save_value(typ, member): ret += "(" if member.membname in mark_offset: ret += f"({{ _{member.membname}_offset = ctx->net_offset; " ret += f"validate_{member.typ.typname}(ctx)" if member.membname in mark_offset: ret += "; })" if should_save_value(typ, member): nbytes = member.static_size assert nbytes if nbytes == 1: ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" else: ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: def lookup_sym(sym: str) -> str: match sym: case "end": return "ctx->net_offset" case _: assert sym.startswith("&") return f"_{sym[1:]}_offset" if member.max: assert member.static_size nbits = member.static_size * 8 ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += cutil.ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += cutil.ifdef_pop(0) return ret