summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-23 01:22:27 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-23 03:05:06 -0600
commit77249bb45c44ec88c96cd00da0805e1a58a1bfd6 (patch)
tree2e35e2de507f49c38eba94153394dbde666e9cd4
parentc1a1f287ed883bed049627da0fd8395197ebf876 (diff)
lib9p: protogen: pull c9util.py out of __init__.py
-rw-r--r--lib9p/protogen/__init__.py264
-rw-r--r--lib9p/protogen/c9util.py109
2 files changed, 214 insertions, 159 deletions
diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py
index 8a9a371..73542a2 100644
--- a/lib9p/protogen/__init__.py
+++ b/lib9p/protogen/__init__.py
@@ -12,7 +12,7 @@ import typing
import idl
-from . import cutil
+from . import c9util, cutil
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
@@ -22,70 +22,6 @@ from . import cutil
# pylint: disable=unused-variable
__all__ = ["main"]
-# Utilities ####################################################################
-
-idprefix = "lib9p_"
-
-
-def add_prefix(p: str, s: str) -> str:
- if s.startswith("_"):
- return "_" + p + s[1:]
- return p + s
-
-
-def c_ver_enum(ver: str) -> str:
- return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
-
-
-def c_ver_ifdef(versions: typing.Collection[str]) -> str:
- return " || ".join(
- f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions)
- )
-
-
-def c_ver_cond(versions: typing.Collection[str]) -> str:
- if len(versions) == 1:
- v = next(v for v in versions)
- return f"is_ver(ctx, {v.replace('.', '_')})"
- return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-
-
-def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
- match typ:
- case idl.Primitive():
- if typ.value == 1 and parent and parent.cnt: # SPECIAL (string)
- return "[[gnu::nonstring]] char"
- return f"uint{typ.value*8}_t"
- case idl.Number():
- return f"{idprefix}{typ.typname}_t"
- case idl.Bitfield():
- return f"{idprefix}{typ.typname}_t"
- case idl.Message():
- return f"struct {idprefix}msg_{typ.typname}"
- case idl.Struct():
- return f"struct {idprefix}{typ.typname}"
- case _:
- raise ValueError(f"not a type: {typ.__class__.__name__}")
-
-
-def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
- ret: list[str] = []
- for tok in expr.tokens:
- match tok:
- case idl.ExprOp():
- ret.append(tok.op)
- case idl.ExprLit():
- ret.append(str(tok.val))
- case idl.ExprSym(symname="s32_max"):
- ret.append("INT32_MAX")
- case idl.ExprSym(symname="s64_max"):
- ret.append("INT64_MAX")
- case idl.ExprSym():
- ret.append(lookup_sym(tok.symname))
- case _:
- assert False
- return " ".join(ret)
-
# topo_sorted() ################################################################
@@ -343,13 +279,13 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
"""
for ver in sorted(versions):
ret += "\n"
- ret += f"#ifndef {c_ver_ifdef({ver})}\n"
- ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
+ ret += f"#ifndef {c9util.ver_ifdef({ver})}\n"
+ ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n"
if ver == "9P2000.e": # SPECIAL (9P2000.e)
ret += "#else\n"
- ret += f"\t#if {c_ver_ifdef({ver})}\n"
+ ret += f"\t#if {c9util.ver_ifdef({ver})}\n"
ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
- ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
+ ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
ret += "\t\t#endif\n"
ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
ret += "\t#endif\n"
@@ -358,31 +294,31 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
ret += f"""
/* enum version ***************************************************************/
-enum {idprefix}version {{
+enum {c9util.ident('version')} {{
"""
fullversions = ["unknown = 0", *sorted(versions)]
verwidth = max(len(v) for v in fullversions)
for ver in fullversions:
if ver in versions:
- ret += cutil.ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t{c_ver_enum(ver)},"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t{c9util.ver_enum(ver)},"
ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
ret += cutil.ifdef_pop(0)
- ret += f"\t{c_ver_enum('NUM')},\n"
+ ret += f"\t{c9util.ver_enum('NUM')},\n"
ret += "};\n"
ret += """
/* enum msg_type **************************************************************/
"""
- ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
+ ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n"
namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
for n in range(0x100):
if n not in id2typ:
continue
msg = id2typ[n]
- ret += cutil.ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions))
+ ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n"
ret += cutil.ifdef_pop(0)
ret += "};\n"
@@ -402,14 +338,14 @@ enum {idprefix}version {{
assert False
else:
ret = ""
- v_width = max(len(c_ver_enum(v)) for v in typ.in_versions)
+ v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions)
for version, line in lines.items():
- ret += f"/* {c_ver_enum(version):<{v_width}}: {line} */\n"
+ ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n"
return ret
for typ in topo_sorted(typs):
ret += "\n"
- ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions))
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
def sum_size(typ: idl.UserType, version: str) -> str:
sz = get_buffer_size(typ, version)
@@ -432,13 +368,13 @@ enum {idprefix}version {{
match typ:
case idl.Number():
- ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
- prefix = f"{idprefix.upper()}{typ.typname.upper()}_"
+ ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
+ prefix = f"{c9util.IDENT(typ.typname)}_"
namewidth = max(len(name) for name in typ.vals)
for name, val in typ.vals.items():
- ret += f"#define {prefix}{name:<{namewidth}} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
+ ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
case idl.Bitfield():
- ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
+ ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
def bitname(val: idl.Bit | idl.BitAlias) -> str:
s = val.bitname
@@ -456,7 +392,7 @@ enum {idprefix}version {{
s = f"_{s}_{n}"
case idl.Bit(cat=idl.BitCat.UNUSED):
return ""
- return add_prefix(f"{idprefix.upper()}{typ.typname.upper()}_", s)
+ return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s))
namewidth = max(
len(bitname(val)) for val in [*typ.bits, *typ.names.values()]
@@ -467,7 +403,7 @@ enum {idprefix}version {{
vers = bit.in_versions
if bit.cat == idl.BitCat.UNUSED:
vers = typ.in_versions
- ret += cutil.ifdef_push(2, c_ver_ifdef(vers))
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers))
# It is important all of the `beg` strings have
# the same length.
@@ -486,7 +422,7 @@ enum {idprefix}version {{
c_name = bitname(bit)
c_val = f"1<<{bit.num}"
- ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n"
+ ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n"
if aliases := [
alias
for alias in typ.names.values()
@@ -495,7 +431,7 @@ enum {idprefix}version {{
ret += "\n"
for alias in aliases:
- ret += cutil.ifdef_push(2, c_ver_ifdef(alias.in_versions))
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions))
end = ""
if cutil.ifdef_leaf_is_noop():
@@ -505,23 +441,23 @@ enum {idprefix}version {{
c_name = bitname(alias)
c_val = alias.val
- ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n"
+ ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n"
ret += cutil.ifdef_pop(1)
del bitname
case idl.Struct(): # and idl.Message():
- ret += c_typename(typ) + " {"
+ ret += c9util.typename(typ) + " {"
if not typ.members:
ret += "};\n"
continue
ret += "\n"
- typewidth = max(len(c_typename(m.typ, m)) for m in typ.members)
+ typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members)
for member in typ.members:
if member.val:
continue
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n"
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n"
ret += cutil.ifdef_pop(1)
ret += "};\n"
del typ
@@ -531,7 +467,7 @@ enum {idprefix}version {{
/* containers *****************************************************************/
"""
ret += "\n"
- ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n"
+ ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n"
tmsg_max_iov: dict[str, int] = {}
tmsg_max_copy: dict[str, int] = {}
@@ -570,7 +506,7 @@ enum {idprefix}version {{
directive = "if"
seen_e = False # SPECIAL (9P2000.e)
for maxval in sorted(inv, reverse=True):
- ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n"
+ ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n"
indent = 1
if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
typ = next(typ for typ in typs if typ.typname == "Tswrite")
@@ -582,11 +518,11 @@ enum {idprefix}version {{
maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
case _:
assert False
- ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n"
- ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n"
+ ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n"
+ ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n"
ret += "\t#else\n"
indent += 1
- ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n"
+ ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n"
if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
ret += "\t#endif\n"
if "9P2000.e" in inv[maxval]:
@@ -595,17 +531,17 @@ enum {idprefix}version {{
ret += "#endif\n"
ret += "\n"
- ret += f"struct {idprefix}Tmsg_send_buf {{\n"
+ ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n"
ret += "\tsize_t iov_cnt;\n"
- ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n"
- ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n"
+ ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n"
+ ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n"
ret += "};\n"
ret += "\n"
- ret += f"struct {idprefix}Rmsg_send_buf {{\n"
+ ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n"
ret += "\tsize_t iov_cnt;\n"
- ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n"
- ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n"
+ ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n"
+ ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n"
ret += "};\n"
return ret
@@ -647,11 +583,11 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
id2typ[msg.msgid] = msg
def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
- ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
+ ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
for ver in ["unknown", *sorted(versions)]:
if ver != "unknown":
- ret += cutil.ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {{\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t[{c9util.ver_enum(ver)}] = {{\n"
for n in range(*rng):
xmsg: idl.Message | None = id2typ.get(n, None)
if xmsg:
@@ -670,14 +606,16 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
for v in sorted(versions):
ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
- ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
+ ret += (
+ f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n"
+ )
ret += "#else\n"
ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
ret += "#endif\n"
ret += "\n"
ret += "/**\n"
- ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n"
- ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n"
+ ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n"
+ ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n"
ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n"
ret += " * several version checks together.\n"
ret += " */\n"
@@ -687,17 +625,17 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
ret += f"""
/* strings ********************************************************************/
-const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
+const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{
"""
for ver in ["unknown", *sorted(versions)]:
if ver in versions:
- ret += cutil.ifdef_push(1, c_ver_ifdef({ver}))
- ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n'
ret += cutil.ifdef_pop(0)
ret += "};\n"
ret += "\n"
- ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
+ ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n"
ret += msg_table("msg", "name", "char *const", (0, 0x100, 1))
# bitmasks #################################################################
@@ -708,13 +646,13 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
if not isinstance(typ, idl.Bitfield):
continue
ret += "\n"
- ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n"
verwidth = max(len(ver) for ver in versions)
for ver in sorted(versions):
- ret += cutil.ifdef_push(2, c_ver_ifdef({ver}))
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver}))
ret += (
- f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
+ f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
+ "".join(
(
"1"
@@ -778,7 +716,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
- ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions))
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
match typ:
@@ -787,11 +725,11 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
case idl.Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
- ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
+ ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
if typ.static_size == 1:
- ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
+ ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
else:
- ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
ret += "\tif (val & ~mask)\n"
ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
@@ -804,8 +742,8 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
# Pass 1 - declare value variables
for member in typ.members:
if should_save_value(typ, member):
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ)} {member.membname};\n"
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
ret += cutil.ifdef_pop(1)
# Pass 2 - declare offset variables
@@ -820,15 +758,15 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
# Pass 3 - main pass
ret += "\treturn false\n"
for member in typ.members:
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
ret += "\t || "
if member.in_versions != typ.in_versions:
- ret += "( " + c_ver_cond(member.in_versions) + " && "
+ ret += "( " + c9util.ver_cond(member.in_versions) + " && "
if member.cnt is not None:
if member.typ.static_size == 1: # SPECIAL (zerocopy)
ret += f"_validate_size_net(ctx, {member.cnt.membname})"
else:
- ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))"
+ ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
if typ.typname == "s": # SPECIAL (string)
ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
else:
@@ -864,14 +802,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
if member.max:
assert member.static_size
nbits = member.static_size * 8
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
if member.val:
assert member.static_size
nbits = member.static_size * 8
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
ret += cutil.ifdef_pop(1)
@@ -907,18 +845,18 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
- ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n"
match typ:
case idl.Number():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
case idl.Bitfield():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
case idl.Struct():
ret += "\tmemset(out, 0, sizeof(*out));\n"
for member in typ.members:
- ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions))
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
if member.val:
ret += f"\tctx->net_offset += {member.static_size};\n"
continue
@@ -926,7 +864,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
prefix = "\t"
if member.in_versions != typ.in_versions:
- ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
+ ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) "
prefix = "\t\t"
if member.cnt:
if member.in_versions != typ.in_versions:
@@ -1041,14 +979,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
oneline.append(f"({cnt.c_str(root)})*{sub.static}")
continue
loopvar = chr(ord("i") + loop_depth)
- multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
+ multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
multiline += sub.gen_c(
"", dstvar, root, indent_depth + 1, loop_depth + 1
)
multiline += f"{'\t'*indent_depth}}}\n"
for vers, sub in self.cond.items():
- multiline += cutil.ifdef_push(indent_depth + 1, c_ver_ifdef(vers))
- multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n"
+ multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
+ multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
multiline += f"{'\t'*indent_depth}}}\n"
multiline += cutil.ifdef_pop(indent_depth)
@@ -1143,8 +1081,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
continue
assert isinstance(typ, idl.Struct)
ret += "\n"
- ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n"
# Pass 1 - check size
max_size = max(typ.max_size(v) for v in typ.in_versions)
@@ -1235,10 +1173,10 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
if child.in_versions < parent.in_versions:
ret += cutil.ifdef_push(
- ifdef_depth + 1, c_ver_ifdef(child.in_versions)
+ ifdef_depth + 1, c9util.ver_ifdef(child.in_versions)
)
ifdef_depth += 1
- ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n"
+ ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n"
stack.append((path, True))
if child.cnt:
cnt_path = path.parent().add(child.cnt)
@@ -1249,7 +1187,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
return WalkCmd.KEEP_GOING, pop
loopvar = chr(ord("i") + loopdepth - 1)
- ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
+ ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
stack.append((path, False))
if not isinstance(child.typ, idl.Struct):
if child.val:
@@ -1264,7 +1202,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
+ sym
)
- val = c_expr(child.val, lookup_sym)
+ val = c9util.idl_expr(child.val, lookup_sym)
else:
val = path.c_str("val->")
if isinstance(child.typ, idl.Bitfield):
@@ -1287,45 +1225,53 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
"""
ret += "\n"
- ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n"
+ ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n"
rerror = next(typ for typ in typs if typ.typname == "Rerror")
- ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
+ ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
for ver in sorted(versions):
- ret += cutil.ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n"
ret += cutil.ifdef_pop(0)
ret += "};\n"
ret += "\n"
ret += cutil.macro(
- f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
- f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n"
+ f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
+ f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n"
f"\t\t.validate = validate_##typ,\n"
f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
f"\t}}\n"
)
ret += cutil.macro(
- f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
+ f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n"
f"\t}}\n"
)
ret += "\n"
- ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2))
+ ret += msg_table(
+ "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2)
+ )
ret += "\n"
- ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2))
+ ret += msg_table(
+ "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2)
+ )
ret += "\n"
- ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2))
+ ret += msg_table(
+ "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2)
+ )
ret += "\n"
- ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2))
+ ret += msg_table(
+ "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2)
+ )
ret += f"""
-LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{
+LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{
\treturn validate_stat(ctx);
}}
-LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{
+LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{
\tunmarshal_stat(ctx, out);
}}
-LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{
+LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{
\treturn marshal_stat(ctx, val);
}}
"""
diff --git a/lib9p/protogen/c9util.py b/lib9p/protogen/c9util.py
new file mode 100644
index 0000000..e7ad999
--- /dev/null
+++ b/lib9p/protogen/c9util.py
@@ -0,0 +1,109 @@
+# lib9p/protogen/c9util.py - Utilities for generating lib9p-specific C
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import typing
+
+import idl
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+# pylint: disable=unused-variable
+__all__ = [
+ "add_prefix",
+ "ident",
+ "Ident",
+ "IDENT",
+ "ver_enum",
+ "ver_ifdef",
+ "ver_cond",
+ "typename",
+ "idl_expr",
+]
+
+# idents #######################################################################
+
+
+def add_prefix(p: str, s: str) -> str:
+ if s.startswith("_"):
+ return "_" + p + s[1:]
+ return p + s
+
+
+def _ident(p: str, s: str) -> str:
+ return add_prefix(p, s.replace(".", "_"))
+
+
+def ident(s: str) -> str:
+ return _ident("lib9p_", s)
+
+
+def Ident(s: str) -> str:
+ return _ident("lib9p_".upper(), s)
+
+
+def IDENT(s: str) -> str:
+ return _ident("lib9p_", s).upper()
+
+
+# versions #####################################################################
+
+
+def ver_enum(ver: str) -> str:
+ return Ident("VER_" + ver)
+
+
+def ver_ifdef(versions: typing.Collection[str]) -> str:
+ return " || ".join(
+ f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions)
+ )
+
+
+def ver_cond(versions: typing.Collection[str]) -> str:
+ if len(versions) == 1:
+ v = next(v for v in versions)
+ return f"is_ver(ctx, {v.replace('.', '_')})"
+ return "( " + (" || ".join(ver_cond({v}) for v in sorted(versions))) + " )"
+
+
+# misc #########################################################################
+
+
+def typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
+ match typ:
+ case idl.Primitive():
+ if typ.value == 1 and parent and parent.cnt: # SPECIAL (string)
+ return "[[gnu::nonstring]] char"
+ return f"uint{typ.value*8}_t"
+ case idl.Number():
+ return ident(f"{typ.typname}_t")
+ case idl.Bitfield():
+ return ident(f"{typ.typname}_t")
+ case idl.Message():
+ return f"struct {ident(f'msg_{typ.typname}')}"
+ case idl.Struct():
+ return f"struct {ident(typ.typname)}"
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def idl_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
+ ret: list[str] = []
+ for tok in expr.tokens:
+ match tok:
+ case idl.ExprOp():
+ ret.append(tok.op)
+ case idl.ExprLit():
+ ret.append(str(tok.val))
+ case idl.ExprSym(symname="s32_max"):
+ ret.append("INT32_MAX")
+ case idl.ExprSym(symname="s64_max"):
+ ret.append("INT64_MAX")
+ case idl.ExprSym():
+ ret.append(lookup_sym(tok.symname))
+ case _:
+ assert False
+ return " ".join(ret)