diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-23 01:22:27 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-23 03:05:06 -0600 |
commit | 77249bb45c44ec88c96cd00da0805e1a58a1bfd6 (patch) | |
tree | 2e35e2de507f49c38eba94153394dbde666e9cd4 | |
parent | c1a1f287ed883bed049627da0fd8395197ebf876 (diff) |
lib9p: protogen: pull c9util.py out of __init__.py
-rw-r--r-- | lib9p/protogen/__init__.py | 264 | ||||
-rw-r--r-- | lib9p/protogen/c9util.py | 109 |
2 files changed, 214 insertions, 159 deletions
diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 8a9a371..73542a2 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -12,7 +12,7 @@ import typing import idl -from . import cutil +from . import c9util, cutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -22,70 +22,6 @@ from . import cutil # pylint: disable=unused-variable __all__ = ["main"] -# Utilities #################################################################### - -idprefix = "lib9p_" - - -def add_prefix(p: str, s: str) -> str: - if s.startswith("_"): - return "_" + p + s[1:] - return p + s - - -def c_ver_enum(ver: str) -> str: - return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" - - -def c_ver_ifdef(versions: typing.Collection[str]) -> str: - return " || ".join( - f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) - ) - - -def c_ver_cond(versions: typing.Collection[str]) -> str: - if len(versions) == 1: - v = next(v for v in versions) - return f"is_ver(ctx, {v.replace('.', '_')})" - return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" - - -def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: - match typ: - case idl.Primitive(): - if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) - return "[[gnu::nonstring]] char" - return f"uint{typ.value*8}_t" - case idl.Number(): - return f"{idprefix}{typ.typname}_t" - case idl.Bitfield(): - return f"{idprefix}{typ.typname}_t" - case idl.Message(): - return f"struct {idprefix}msg_{typ.typname}" - case idl.Struct(): - return f"struct {idprefix}{typ.typname}" - case _: - raise ValueError(f"not a type: {typ.__class__.__name__}") - - -def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: - ret: list[str] = [] - for tok in expr.tokens: - match tok: - case idl.ExprOp(): - ret.append(tok.op) - case idl.ExprLit(): - ret.append(str(tok.val)) - case idl.ExprSym(symname="s32_max"): - ret.append("INT32_MAX") - case idl.ExprSym(symname="s64_max"): - ret.append("INT64_MAX") - case idl.ExprSym(): - ret.append(lookup_sym(tok.symname)) - case _: - assert False - return " ".join(ret) - # topo_sorted() ################################################################ @@ -343,13 +279,13 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: """ for ver in sorted(versions): ret += "\n" - ret += f"#ifndef {c_ver_ifdef({ver})}\n" - ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" + ret += f"#ifndef {c9util.ver_ifdef({ver})}\n" + ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n" if ver == "9P2000.e": # SPECIAL (9P2000.e) ret += "#else\n" - ret += f"\t#if {c_ver_ifdef({ver})}\n" + ret += f"\t#if {c9util.ver_ifdef({ver})}\n" ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" - ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += "\t\t#endif\n" ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" ret += "\t#endif\n" @@ -358,31 +294,31 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: ret += f""" /* enum version ***************************************************************/ -enum {idprefix}version {{ +enum {c9util.ident('version')} {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t{c_ver_enum(ver)}," + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t{c9util.ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += cutil.ifdef_pop(0) - ret += f"\t{c_ver_enum('NUM')},\n" + ret += f"\t{c9util.ver_enum('NUM')},\n" ret += "};\n" ret += """ /* enum msg_type **************************************************************/ """ - ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] - ret += cutil.ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) + ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" ret += cutil.ifdef_pop(0) ret += "};\n" @@ -402,14 +338,14 @@ enum {idprefix}version {{ assert False else: ret = "" - v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) + v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) for version, line in lines.items(): - ret += f"/* {c_ver_enum(version):<{v_width}}: {line} */\n" + ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" return ret for typ in topo_sorted(typs): ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) @@ -432,13 +368,13 @@ enum {idprefix}version {{ match typ: case idl.Number(): - ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" - prefix = f"{idprefix.upper()}{typ.typname.upper()}_" + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" + prefix = f"{c9util.IDENT(typ.typname)}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): - ret += f"#define {prefix}{name:<{namewidth}} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" + ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" case idl.Bitfield(): - ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" def bitname(val: idl.Bit | idl.BitAlias) -> str: s = val.bitname @@ -456,7 +392,7 @@ enum {idprefix}version {{ s = f"_{s}_{n}" case idl.Bit(cat=idl.BitCat.UNUSED): return "" - return add_prefix(f"{idprefix.upper()}{typ.typname.upper()}_", s) + return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) namewidth = max( len(bitname(val)) for val in [*typ.bits, *typ.names.values()] @@ -467,7 +403,7 @@ enum {idprefix}version {{ vers = bit.in_versions if bit.cat == idl.BitCat.UNUSED: vers = typ.in_versions - ret += cutil.ifdef_push(2, c_ver_ifdef(vers)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. @@ -486,7 +422,7 @@ enum {idprefix}version {{ c_name = bitname(bit) c_val = f"1<<{bit.num}" - ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" if aliases := [ alias for alias in typ.names.values() @@ -495,7 +431,7 @@ enum {idprefix}version {{ ret += "\n" for alias in aliases: - ret += cutil.ifdef_push(2, c_ver_ifdef(alias.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) end = "" if cutil.ifdef_leaf_is_noop(): @@ -505,23 +441,23 @@ enum {idprefix}version {{ c_name = bitname(alias) c_val = alias.val - ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" ret += cutil.ifdef_pop(1) del bitname case idl.Struct(): # and idl.Message(): - ret += c_typename(typ) + " {" + ret += c9util.typename(typ) + " {" if not typ.members: ret += "};\n" continue ret += "\n" - typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) + typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" ret += cutil.ifdef_pop(1) ret += "};\n" del typ @@ -531,7 +467,7 @@ enum {idprefix}version {{ /* containers *****************************************************************/ """ ret += "\n" - ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n" + ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n" tmsg_max_iov: dict[str, int] = {} tmsg_max_copy: dict[str, int] = {} @@ -570,7 +506,7 @@ enum {idprefix}version {{ directive = "if" seen_e = False # SPECIAL (9P2000.e) for maxval in sorted(inv, reverse=True): - ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" + ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) typ = next(typ for typ in typs if typ.typname == "Tswrite") @@ -582,11 +518,11 @@ enum {idprefix}version {{ maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" case _: assert False - ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" - ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" + ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" ret += "\t#else\n" indent += 1 - ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" + ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) ret += "\t#endif\n" if "9P2000.e" in inv[maxval]: @@ -595,17 +531,17 @@ enum {idprefix}version {{ ret += "#endif\n" ret += "\n" - ret += f"struct {idprefix}Tmsg_send_buf {{\n" + ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" - ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" ret += "};\n" ret += "\n" - ret += f"struct {idprefix}Rmsg_send_buf {{\n" + ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" - ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" ret += "};\n" return ret @@ -647,11 +583,11 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: id2typ[msg.msgid] = msg def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: - ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" for ver in ["unknown", *sorted(versions)]: if ver != "unknown": - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: @@ -670,14 +606,16 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: for v in sorted(versions): ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" + ret += ( + f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" + ) ret += "#else\n" ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" ret += "#endif\n" ret += "\n" ret += "/**\n" - ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" - ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n" + ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" + ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" ret += " * several version checks together.\n" ret += " */\n" @@ -687,17 +625,17 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: ret += f""" /* strings ********************************************************************/ -const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ +const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" - ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" + ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) # bitmasks ################################################################# @@ -708,13 +646,13 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ if not isinstance(typ, idl.Bitfield): continue ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): - ret += cutil.ifdef_push(2, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) ret += ( - f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( ( "1" @@ -778,7 +716,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: @@ -787,11 +725,11 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" - ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" + ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: - ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: - ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += "\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" @@ -804,8 +742,8 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 1 - declare value variables for member in typ.members: if should_save_value(typ, member): - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ)} {member.membname};\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" ret += cutil.ifdef_pop(1) # Pass 2 - declare offset variables @@ -820,15 +758,15 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: - ret += "( " + c_ver_cond(member.in_versions) + " && " + ret += "( " + c9util.ver_cond(member.in_versions) + " && " if member.cnt is not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: - ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))" + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" if typ.typname == "s": # SPECIAL (string) ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: @@ -864,14 +802,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val if member.max: assert member.static_size nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += cutil.ifdef_pop(1) @@ -907,18 +845,18 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" match typ: case idl.Number(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) if member.val: ret += f"\tctx->net_offset += {member.static_size};\n" continue @@ -926,7 +864,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o prefix = "\t" if member.in_versions != typ.in_versions: - ret += "if ( " + c_ver_cond(member.in_versions) + " ) " + ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " prefix = "\t\t" if member.cnt: if member.in_versions != typ.in_versions: @@ -1041,14 +979,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o oneline.append(f"({cnt.c_str(root)})*{sub.static}") continue loopvar = chr(ord("i") + loop_depth) - multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" multiline += sub.gen_c( "", dstvar, root, indent_depth + 1, loop_depth + 1 ) multiline += f"{'\t'*indent_depth}}}\n" for vers, sub in self.cond.items(): - multiline += cutil.ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) - multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" + multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) multiline += f"{'\t'*indent_depth}}}\n" multiline += cutil.ifdef_pop(indent_depth) @@ -1143,8 +1081,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o continue assert isinstance(typ, idl.Struct) ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) @@ -1235,10 +1173,10 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: ret += cutil.ifdef_push( - ifdef_depth + 1, c_ver_ifdef(child.in_versions) + ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) ) ifdef_depth += 1 - ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" + ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" stack.append((path, True)) if child.cnt: cnt_path = path.parent().add(child.cnt) @@ -1249,7 +1187,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" return WalkCmd.KEEP_GOING, pop loopvar = chr(ord("i") + loopdepth - 1) - ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" stack.append((path, False)) if not isinstance(child.typ, idl.Struct): if child.val: @@ -1264,7 +1202,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o + sym ) - val = c_expr(child.val, lookup_sym) + val = c9util.idl_expr(child.val, lookup_sym) else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): @@ -1287,45 +1225,53 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o """ ret += "\n" - ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" + ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" rerror = next(typ for typ in typs if typ.typname == "Rerror") - ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) for ver in sorted(versions): - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" ret += cutil.macro( - f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" - f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" + f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" + f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" f"\t\t.validate = validate_##typ,\n" f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" f"\t}}\n" ) ret += cutil.macro( - f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" f"\t}}\n" ) ret += "\n" - ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) + ret += msg_table( + "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) + ) ret += "\n" - ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) + ret += msg_table( + "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) + ) ret += "\n" - ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) + ret += msg_table( + "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) + ) ret += "\n" - ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) + ret += msg_table( + "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) + ) ret += f""" -LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ +LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ \treturn validate_stat(ctx); }} -LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{ +LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ \tunmarshal_stat(ctx, out); }} -LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{ +LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ \treturn marshal_stat(ctx, val); }} """ diff --git a/lib9p/protogen/c9util.py b/lib9p/protogen/c9util.py new file mode 100644 index 0000000..e7ad999 --- /dev/null +++ b/lib9p/protogen/c9util.py @@ -0,0 +1,109 @@ +# lib9p/protogen/c9util.py - Utilities for generating lib9p-specific C +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import typing + +import idl + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + +# pylint: disable=unused-variable +__all__ = [ + "add_prefix", + "ident", + "Ident", + "IDENT", + "ver_enum", + "ver_ifdef", + "ver_cond", + "typename", + "idl_expr", +] + +# idents ####################################################################### + + +def add_prefix(p: str, s: str) -> str: + if s.startswith("_"): + return "_" + p + s[1:] + return p + s + + +def _ident(p: str, s: str) -> str: + return add_prefix(p, s.replace(".", "_")) + + +def ident(s: str) -> str: + return _ident("lib9p_", s) + + +def Ident(s: str) -> str: + return _ident("lib9p_".upper(), s) + + +def IDENT(s: str) -> str: + return _ident("lib9p_", s).upper() + + +# versions ##################################################################### + + +def ver_enum(ver: str) -> str: + return Ident("VER_" + ver) + + +def ver_ifdef(versions: typing.Collection[str]) -> str: + return " || ".join( + f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) + ) + + +def ver_cond(versions: typing.Collection[str]) -> str: + if len(versions) == 1: + v = next(v for v in versions) + return f"is_ver(ctx, {v.replace('.', '_')})" + return "( " + (" || ".join(ver_cond({v}) for v in sorted(versions))) + " )" + + +# misc ######################################################################### + + +def typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: + match typ: + case idl.Primitive(): + if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) + return "[[gnu::nonstring]] char" + return f"uint{typ.value*8}_t" + case idl.Number(): + return ident(f"{typ.typname}_t") + case idl.Bitfield(): + return ident(f"{typ.typname}_t") + case idl.Message(): + return f"struct {ident(f'msg_{typ.typname}')}" + case idl.Struct(): + return f"struct {ident(typ.typname)}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def idl_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: + ret: list[str] = [] + for tok in expr.tokens: + match tok: + case idl.ExprOp(): + ret.append(tok.op) + case idl.ExprLit(): + ret.append(str(tok.val)) + case idl.ExprSym(symname="s32_max"): + ret.append("INT32_MAX") + case idl.ExprSym(symname="s64_max"): + ret.append("INT64_MAX") + case idl.ExprSym(): + ret.append(lookup_sym(tok.symname)) + case _: + assert False + return " ".join(ret) |