diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-01-19 15:53:46 -0700 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-01-19 15:53:46 -0700 |
commit | 104ea21b497171f5a1c4ba80d82337da3f7c2632 (patch) | |
tree | 9b5a167833b9caa4f8f829c9bc7a3711a1cd837a /lib9p/idl.gen | |
parent | a35db3be439c9a27f0763036cf3d4992ccf893eb (diff) | |
parent | 0ab9da9bc3c6cdaef00b7202ba03eff917b44c95 (diff) |
Merge branch 'lukeshu/9p-tidy'
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 219 |
1 files changed, 129 insertions, 90 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index 47ca49a..7c79a28 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -23,6 +23,9 @@ import idl idprefix = "lib9p_" +u32max = (1 << 32) - 1 +u64max = (1 << 64) - 1 + def tab_ljust(s: str, width: int) -> str: cur = len(s.expandtabs(tabsize=8)) @@ -67,9 +70,11 @@ def c_ver_cond(versions: set[str]) -> str: return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" -def c_typename(typ: idl.Type) -> str: +def c_typename(typ: idl.Type, parent: idl.Type | None = None) -> str: match typ: case idl.Primitive(): + if typ.value == 1 and parent and parent.name in ["d", "s"]: # SPECIAL + return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): return f"{idprefix}{typ.name}_t" @@ -88,17 +93,17 @@ def c_expr(expr: idl.Expr) -> str: for tok in expr.tokens: match tok: case idl.ExprOp(): - ret += [tok.op] + ret.append(tok.op) case idl.ExprLit(): - ret += [str(tok.val)] + ret.append(str(tok.val)) case idl.ExprSym(name="end"): - ret += ["ctx->net_offset"] + ret.append("ctx->net_offset") case idl.ExprSym(name="s32_max"): - ret += ["INT32_MAX"] + ret.append("INT32_MAX") case idl.ExprSym(name="s64_max"): - ret += ["INT64_MAX"] + ret.append("INT64_MAX") case idl.ExprSym(): - ret += [f"_{tok.name[1:]}_offset"] + ret.append(f"_{tok.name[1:]}_offset") return " ".join(ret) @@ -109,7 +114,7 @@ def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed global _ifdef_stack while len(_ifdef_stack) < n: - _ifdef_stack += [None] + _ifdef_stack.append(None) # Set some variables parentval: str | None = None @@ -204,8 +209,6 @@ enum {idprefix}version {{ ret += ifdef_pop(0) ret += f"\t{c_ver_enum('NUM')},\n" ret += "};\n" - ret += "\n" - ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ /* enum msg_type **************************************************************/ @@ -222,9 +225,43 @@ enum {idprefix}version {{ ret += """ /* payload types **************************************************************/ """ + + def per_version_comment( + typ: idl.Type, fn: typing.Callable[[idl.Type, str], str] + ) -> str: + lines: dict[str, str] = {} + for version in sorted(typ.in_versions): + lines[version] = fn(typ, version) + if len(set(lines.values())) == 1: + for _, line in lines.items(): + return f"/* {line} */\n" + assert False + else: + ret = "" + v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) + for version, line in lines.items(): + ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n" + return ret + for typ in topo_sorted(typs): ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + + def sum_size(typ: idl.Type, version: str) -> str: + min_size = typ.min_size(version) + max_size = typ.max_size(version) + assert min_size <= max_size and max_size < u64max + ret = "" + if min_size == max_size: + ret += f"size = {min_size:,}" + else: + ret += f"min_size = {min_size:,} ; max_size = {max_size:,}" + if max_size > u32max: + ret += " (warning: >UINT32_MAX)" + return ret + + ret += per_version_comment(typ, sum_size) + match typ: case idl.Number(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" @@ -281,16 +318,13 @@ enum {idprefix}version {{ continue ret += "\n" - typewidth = max(len(c_typename(m.typ)) for m in typ.members) + typewidth = max(len(c_typename(m.typ, typ)) for m in typ.members) for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - c_type = c_typename(member.typ) - if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL - c_type = "char" - ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -330,6 +364,38 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: def unused(arg: str) -> str: return f"LM_UNUSED({arg})" + for v in sorted(versions): + ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" + ret += "#else\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" + ret += "#endif\n" + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: + ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + for ver in ["unknown", *sorted(versions)]: + if ver != "unknown": + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {{\n" + for n in range(*rng): + xmsg: idl.Message | None = id2typ.get(n, None) + if xmsg: + if ver == "unknown": # SPECIAL + if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: + xmsg = None + else: + if ver not in xmsg.in_versions: + xmsg = None + if xmsg: + ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" + ret += "\t},\n" + ret += ifdef_pop(0) + ret += "};\n" + return ret + ret += "\n" ret += "/**\n" ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" @@ -338,18 +404,12 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: ret += " * several version checks together.\n" ret += " */\n" ret += "#define is_ver(ctx, ver) _is_ver_##ver(ctx->ctx->version)\n" - for v in sorted(versions): - ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" - ret += "#else\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" - ret += "#endif\n" # strings ################################################################## ret += f""" /* strings ********************************************************************/ -static const char *version_strs[{c_ver_enum('NUM')}] = {{ +const char *_lib9p_table_ver_name[{c_ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: @@ -357,15 +417,10 @@ static const char *version_strs[{c_ver_enum('NUM')}] = {{ ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' ret += ifdef_pop(0) ret += "};\n" - ret += f""" -const char *{idprefix}version_str(enum {idprefix}version ver) {{ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wtype-limits" - assert(0 <= ver && ver < {c_ver_enum('NUM')}); -#pragma GCC diagnostic pop - return version_strs[ver]; -}} -""" + + ret += "\n" + ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" + ret += msg_table("msg", "name", "char *", (0, 0x100, 1)) # bitmasks ################################################################# ret += f""" @@ -443,7 +498,7 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_4(ctx))\n" ret += "\t\treturn true;\n" - ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" + ret += "\tuint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]);\n" ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" ret += "}\n" continue @@ -453,8 +508,8 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_2(ctx))\n" ret += "\t\treturn true;\n" - ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" - ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" + ret += "\tuint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]);\n" + ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)))\n" ret += "\t\treturn true;\n" ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' @@ -471,7 +526,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += ( f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" ) - ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + if typ.static_size == 1: + ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + else: + ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" @@ -507,7 +565,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: assert prev_size - ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if prev_size == 1: + ret += f"_validate_list(ctx, ctx->net_bytes[ctx->net_offset-1], validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + else: + ret += f"_validate_list(ctx, uint{prev_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" else: if member.max or member.val: ret += "(" @@ -517,10 +578,12 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, if member.name in mark_offset: ret += "; })" if member.max or member.val: - bytes = member.static_size - assert bytes - bits = bytes * 8 - ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" + nbytes = member.static_size + assert nbytes + if nbytes == 1: + ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + else: + ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" @@ -551,22 +614,22 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, /* unmarshal_* ****************************************************************/ LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { -\t*out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); +\t*out = ctx->net_bytes[ctx->net_offset]; \tctx->net_offset += 1; } LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { -\t*out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); +\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 2; } LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { -\t*out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); +\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 4; } LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { -\t*out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); +\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 8; } """ @@ -600,9 +663,10 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "{\n" ret += prefix ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" - ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" if typ.name in ["d", "s"]: # SPECIAL + # Special-case is that we cast from `char` to `uint8_t`. ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" else: ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" @@ -612,9 +676,6 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += ( f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ) - if typ.name == "s": # SPECIAL - ret += "\tctx->extra++;\n" - ret += "\tout->utf8[out->len] = '\\0';\n" ret += ifdef_pop(1) ret += "}\n" ret += ifdef_pop(0) @@ -642,7 +703,7 @@ LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { \tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size) \t\treturn _marshal_too_large(ctx); -\tencode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); +\tuint16le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 2; \treturn false; } @@ -650,7 +711,7 @@ LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { \tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size) \t\treturn true; -\tencode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); +\tuint32le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 4; \treturn false; } @@ -658,7 +719,7 @@ LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { \tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size) \t\treturn true; -\tencode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); +\tuint64le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 8; \treturn false; } @@ -705,7 +766,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) ret += f"({{ _{member.name}_offset = ctx->net_offset; " if member.cnt: ret += "({ bool err = false;\n" - ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n" + ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n" ret += "\t \terr = " if typ.name in ["d", "s"]: # SPECIAL # Special-case is that we cast from `char` to `uint8_t`. @@ -732,22 +793,21 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) if member.val: assert member.static_size ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n" + if member.static_size == 1: + ret += f"\t || ({{ ctx->net_bytes[_{member.name}_offset] = {c_expr(member.val)}; false; }})\n" + else: + ret += f"\t || ({{ uint{member.static_size*8}le_encode(&ctx->net_bytes[_{member.name}_offset], {c_expr(member.val)}); false; }})\n" ret += ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += ifdef_pop(0) - # tables / exports ######################################################### + # function tables ########################################################## ret += """ -/* tables / exports ***********************************************************/ -""" - id2typ: dict[int, idl.Message] = {} - for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: - id2typ[msg.msgid] = msg +/* function tables ************************************************************/ - ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" +""" ret += c_macro( f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{", f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),", @@ -760,35 +820,14 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,", f"\t}}", ) - - tables = [ - ("msg", "name", "char *", (0, 0x100, 1)), - ("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)), - ("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)), - ("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)), - ("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)), - ] - for grp, meth, tentry, rng in tables: - ret += "\n" - ret += f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" - for ver in ["unknown", *sorted(versions)]: - if ver != "unknown": - ret += ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {{\n" - for n in range(*rng): - xmsg: idl.Message | None = id2typ.get(n, None) - if xmsg: - if ver == "unknown": # SPECIAL - if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: - xmsg = None - else: - if ver not in xmsg.in_versions: - xmsg = None - if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" - ret += "\t},\n" - ret += ifdef_pop(0) - ret += "};\n" + ret += "\n" + ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) + ret += "\n" + ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) ret += f""" LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ |