From 652098de8dce01af90a65e0efa48dc156abc6329 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Mon, 3 Feb 2025 20:31:44 -0700 Subject: lib9p: idl.gen: Simplify special cases It wasn't validating d_signed length. --- lib9p/9p.generated.c | 59 +++++++++++++++++++++-------------------- lib9p/idl.gen | 74 ++++++++++++++++++---------------------------------- 2 files changed, 56 insertions(+), 77 deletions(-) diff --git a/lib9p/9p.generated.c b/lib9p/9p.generated.c index 8e3e26b..a166fb2 100644 --- a/lib9p/9p.generated.c +++ b/lib9p/9p.generated.c @@ -354,10 +354,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, return false; } -#define validate_1(ctx) _validate_size_net(ctx, 1) -#define validate_2(ctx) _validate_size_net(ctx, 2) -#define validate_4(ctx) _validate_size_net(ctx, 4) -#define validate_8(ctx) _validate_size_net(ctx, 8) +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } #if CONFIG_9P_ENABLE_9P2000 || CONFIG_9P_ENABLE_9P2000_L || CONFIG_9P_ENABLE_9P2000_e || CONFIG_9P_ENABLE_9P2000_u LM_ALWAYS_INLINE static bool validate_tag(struct _validate_ctx *ctx) { @@ -369,31 +369,30 @@ LM_ALWAYS_INLINE static bool validate_fid(struct _validate_ctx *ctx) { } LM_ALWAYS_INLINE static bool validate_d(struct _validate_ctx *ctx) { - uint32_t base_offset = ctx->net_offset; - if (validate_4(ctx)) - return true; - uint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]); - return _validate_size_net(ctx, len) || _validate_size_host(ctx, len); + uint32_t len; + return false + || (validate_4(ctx) || ({ len = uint32le_decode(&ctx->net_bytes[ctx->net_offset-4]); false; })) + || _validate_list(ctx, len, validate_1, sizeof(uint8_t)) + ; } LM_ALWAYS_INLINE static bool validate_d_signed(struct _validate_ctx *ctx) { - uint32_t base_offset = ctx->net_offset; - if (validate_4(ctx)) - return true; - uint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]); - return _validate_size_net(ctx, len) || _validate_size_host(ctx, len); + uint32_t len; + return false + || (validate_4(ctx) || ({ len = uint32le_decode(&ctx->net_bytes[ctx->net_offset-4]); false; })) + || _validate_list(ctx, len, validate_1, sizeof(uint8_t)) + || ({ uint32_t max = INT32_MAX; (((uint32_t)len) > max) && + lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "len value is too large (%"PRIu32" > %"PRIu32")", len, max); }) + ; } LM_ALWAYS_INLINE static bool validate_s(struct _validate_ctx *ctx) { - uint32_t base_offset = ctx->net_offset; - if (validate_2(ctx)) - return true; - uint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]); - if (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len))) - return true; - if (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len)) - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); - return false; + uint16_t len; + return false + || (validate_2(ctx) || ({ len = uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || _validate_list(ctx, len, validate_1, sizeof(uint8_t)) + || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }) + ; } #endif /* CONFIG_9P_ENABLE_9P2000 || CONFIG_9P_ENABLE_9P2000_L || CONFIG_9P_ENABLE_9P2000_e || CONFIG_9P_ENABLE_9P2000_u */ @@ -1086,7 +1085,7 @@ LM_FLATTEN static bool validate_Twalk(struct _validate_ctx *ctx) { || validate_fid(ctx) || validate_fid(ctx) || (validate_2(ctx) || ({ nwname = uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]); false; })) - || _validate_list(ctx, uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]), validate_s, sizeof(struct lib9p_s)) + || _validate_list(ctx, nwname, validate_s, sizeof(struct lib9p_s)) || ({ uint32_t exp = ctx->net_offset - _size_offset; (((uint32_t)size) != exp) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "size value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)size, exp); }) || ({ uint8_t exp = 110; (((uint8_t)typ) != exp) && @@ -1270,14 +1269,15 @@ LM_FLATTEN static bool validate_Tunlinkat(struct _validate_ctx *ctx) { LM_FLATTEN static bool validate_Tsread(struct _validate_ctx *ctx) { uint32_t size; uint8_t typ; + uint16_t nwname; uint32_t _size_offset; return false || (({ _size_offset = ctx->net_offset; validate_4(ctx); }) || ({ size = uint32le_decode(&ctx->net_bytes[ctx->net_offset-4]); false; })) || (validate_1(ctx) || ({ typ = ctx->net_bytes[ctx->net_offset-1]; false; })) || validate_tag(ctx) || validate_4(ctx) - || validate_2(ctx) - || _validate_list(ctx, uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]), validate_s, sizeof(struct lib9p_s)) + || (validate_2(ctx) || ({ nwname = uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || _validate_list(ctx, nwname, validate_s, sizeof(struct lib9p_s)) || ({ uint32_t exp = ctx->net_offset - _size_offset; (((uint32_t)size) != exp) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "size value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)size, exp); }) || ({ uint8_t exp = 152; (((uint8_t)typ) != exp) && @@ -1288,14 +1288,15 @@ LM_FLATTEN static bool validate_Tsread(struct _validate_ctx *ctx) { LM_FLATTEN static bool validate_Tswrite(struct _validate_ctx *ctx) { uint32_t size; uint8_t typ; + uint16_t nwname; uint32_t _size_offset; return false || (({ _size_offset = ctx->net_offset; validate_4(ctx); }) || ({ size = uint32le_decode(&ctx->net_bytes[ctx->net_offset-4]); false; })) || (validate_1(ctx) || ({ typ = ctx->net_bytes[ctx->net_offset-1]; false; })) || validate_tag(ctx) || validate_4(ctx) - || validate_2(ctx) - || _validate_list(ctx, uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]), validate_s, sizeof(struct lib9p_s)) + || (validate_2(ctx) || ({ nwname = uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || _validate_list(ctx, nwname, validate_s, sizeof(struct lib9p_s)) || validate_d(ctx) || ({ uint32_t exp = ctx->net_offset - _size_offset; (((uint32_t)size) != exp) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "size value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)size, exp); }) @@ -1631,7 +1632,7 @@ LM_FLATTEN static bool validate_Rwalk(struct _validate_ctx *ctx) { || (validate_1(ctx) || ({ typ = ctx->net_bytes[ctx->net_offset-1]; false; })) || validate_tag(ctx) || (validate_2(ctx) || ({ nwqid = uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]); false; })) - || _validate_list(ctx, uint16le_decode(&ctx->net_bytes[ctx->net_offset-2]), validate_qid, sizeof(struct lib9p_qid)) + || _validate_list(ctx, nwqid, validate_qid, sizeof(struct lib9p_qid)) || ({ uint32_t exp = ctx->net_offset - _size_offset; (((uint32_t)size) != exp) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "size value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)size, exp); }) || ({ uint8_t exp = 111; (((uint8_t)typ) != exp) && diff --git a/lib9p/idl.gen b/lib9p/idl.gen index 558592f..8782771 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -66,12 +66,10 @@ def c_ver_cond(versions: set[str]) -> str: return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" -def c_typename(typ: idl.Type, parent: idl.Type | None = None) -> str: +def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: match typ: case idl.Primitive(): - if ( - typ.value == 1 and parent and parent.name in ["d", "d_signed", "s"] - ): # SPECIAL + if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): @@ -322,13 +320,13 @@ enum {idprefix}version {{ continue ret += "\n" - typewidth = max(len(c_typename(m.typ, typ)) for m in typ.members) + typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -381,7 +379,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: - if ver == "unknown": # SPECIAL + if ver == "unknown": # SPECIAL (initialization) if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: @@ -484,10 +482,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, \treturn false; } -#define validate_1(ctx) _validate_size_net(ctx, 1) -#define validate_2(ctx) _validate_size_net(ctx, 2) -#define validate_4(ctx) _validate_size_net(ctx, 4) -#define validate_8(ctx) _validate_size_net(ctx, 8) +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" @@ -496,31 +494,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" - if typ.name == "d" or typ.name == "d_signed": # SPECIAL - # Optimize... maybe the compiler could figure out to do - # this, but let's make it obvious. - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_4(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]);\n" - ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" - ret += "}\n" - continue - if typ.name == "s": # SPECIAL - # Add an extra nul-byte on the host, and validate UTF-8 - # (also, similar optimization to "d"). - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_2(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]);\n" - ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)))\n" - ret += "\t\treturn true;\n" - ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" - ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' - ret += "\treturn false;\n" - ret += "}\n" - continue - match typ: case idl.Number(): ret += f"\treturn validate_{typ.prim.name}(ctx);\n" @@ -543,9 +516,18 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue + def should_save_value(member: idl.StructMember) -> bool: + nonlocal typ + assert isinstance(typ, idl.Struct) + return bool( + member.max + or member.val + or any(m.cnt == member for m in typ.members) + ) + # Pass 1 - declare value variables for member in typ.members: - if member.max or member.val: + if should_save_value(member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ)} {member.name};\n" ret += ifdef_pop(1) @@ -561,27 +543,24 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, # Pass 3 - main pass ret += "\treturn false\n" - prev_size: int | None = None for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: - assert prev_size - if prev_size == 1: - ret += f"_validate_list(ctx, ctx->net_bytes[ctx->net_offset-1], validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" - else: - ret += f"_validate_list(ctx, uint{prev_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if typ.name == "s": # SPECIAL (string) + ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' else: - if member.max or member.val: + if should_save_value(member): ret += "(" if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" if member.name in mark_offset: ret += "; })" - if member.max or member.val: + if should_save_value(member): nbytes = member.static_size assert nbytes if nbytes == 1: @@ -591,7 +570,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, if member.in_versions != typ.in_versions: ret += " )" ret += "\n" - prev_size = member.static_size # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: @@ -669,7 +647,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" - if typ.name in ["d", "d_signed", "s"]: # SPECIAL + if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" else: @@ -772,7 +750,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) ret += "({ bool err = false;\n" ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n" ret += "\t \terr = " - if typ.name in ["d", "d_signed", "s"]: # SPECIAL + if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" else: -- cgit v1.2.3-2-g168b