summary | refs | log | tree | commit | diff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
author Luke T. Shumaker <lukeshu@lukeshu.com> 2025-01-19 15:53:46 -0700
committer Luke T. Shumaker <lukeshu@lukeshu.com> 2025-01-19 15:53:46 -0700
commit 104ea21b497171f5a1c4ba80d82337da3f7c2632 (patch)
tree 9b5a167833b9caa4f8f829c9bc7a3711a1cd837a /lib9p/idl.gen
parent a35db3be439c9a27f0763036cf3d4992ccf893eb (diff)
parent 0ab9da9bc3c6cdaef00b7202ba03eff917b44c95 (diff)
Merge branch 'lukeshu/9p-tidy'
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x lib9p/idl.gen 219
1 files changed, 129 insertions, 90 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index 47ca49a..7c79a28 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -23,6 +23,9 @@ import idl
idprefix = "lib9p_"
+u32max = (1 << 32) - 1
+u64max = (1 << 64) - 1
+
def tab_ljust(s: str, width: int) -> str:
cur = len(s.expandtabs(tabsize=8))
@@ -67,9 +70,11 @@ def c_ver_cond(versions: set[str]) -> str:
return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-def c_typename(typ: idl.Type) -> str:
+def c_typename(typ: idl.Type, parent: idl.Type | None = None) -> str:
match typ:
case idl.Primitive():
+ if typ.value == 1 and parent and parent.name in ["d", "s"]: # SPECIAL
+ return "[[gnu::nonstring]] char"
return f"uint{typ.value*8}_t"
case idl.Number():
return f"{idprefix}{typ.name}_t"
@@ -88,17 +93,17 @@ def c_expr(expr: idl.Expr) -> str:
for tok in expr.tokens:
match tok:
case idl.ExprOp():
- ret += [tok.op]
+ ret.append(tok.op)
case idl.ExprLit():
- ret += [str(tok.val)]
+ ret.append(str(tok.val))
case idl.ExprSym(name="end"):
- ret += ["ctx->net_offset"]
+ ret.append("ctx->net_offset")
case idl.ExprSym(name="s32_max"):
- ret += ["INT32_MAX"]
+ ret.append("INT32_MAX")
case idl.ExprSym(name="s64_max"):
- ret += ["INT64_MAX"]
+ ret.append("INT64_MAX")
case idl.ExprSym():
- ret += [f"_{tok.name[1:]}_offset"]
+ ret.append(f"_{tok.name[1:]}_offset")
return " ".join(ret)
@@ -109,7 +114,7 @@ def ifdef_push(n: int, _newval: str) -> str:
# Grow the stack as needed
global _ifdef_stack
while len(_ifdef_stack) < n:
- _ifdef_stack += [None]
+ _ifdef_stack.append(None)
# Set some variables
parentval: str | None = None
@@ -204,8 +209,6 @@ enum {idprefix}version {{
ret += ifdef_pop(0)
ret += f"\t{c_ver_enum('NUM')},\n"
ret += "};\n"
- ret += "\n"
- ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
ret += """
/* enum msg_type **************************************************************/
@@ -222,9 +225,43 @@ enum {idprefix}version {{
ret += """
/* payload types **************************************************************/
"""
+
+ def per_version_comment(
+ typ: idl.Type, fn: typing.Callable[[idl.Type, str], str]
+ ) -> str:
+ lines: dict[str, str] = {}
+ for version in sorted(typ.in_versions):
+ lines[version] = fn(typ, version)
+ if len(set(lines.values())) == 1:
+ for _, line in lines.items():
+ return f"/* {line} */\n"
+ assert False
+ else:
+ ret = ""
+ v_width = max(len(c_ver_enum(v)) for v in typ.in_versions)
+ for version, line in lines.items():
+ ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n"
+ return ret
+
for typ in topo_sorted(typs):
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+
+ def sum_size(typ: idl.Type, version: str) -> str:
+ min_size = typ.min_size(version)
+ max_size = typ.max_size(version)
+ assert min_size <= max_size and max_size < u64max
+ ret = ""
+ if min_size == max_size:
+ ret += f"size = {min_size:,}"
+ else:
+ ret += f"min_size = {min_size:,} ; max_size = {max_size:,}"
+ if max_size > u32max:
+ ret += " (warning: >UINT32_MAX)"
+ return ret
+
+ ret += per_version_comment(typ, sum_size)
+
match typ:
case idl.Number():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
@@ -281,16 +318,13 @@ enum {idprefix}version {{
continue
ret += "\n"
- typewidth = max(len(c_typename(m.typ)) for m in typ.members)
+ typewidth = max(len(c_typename(m.typ, typ)) for m in typ.members)
for member in typ.members:
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- c_type = c_typename(member.typ)
- if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL
- c_type = "char"
- ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ ret += f"\t{c_typename(member.typ, typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
ret += ifdef_pop(1)
ret += "};\n"
ret += ifdef_pop(0)
@@ -330,6 +364,38 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
def unused(arg: str) -> str:
return f"LM_UNUSED({arg})"
+ for v in sorted(versions):
+ ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
+ ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
+ ret += "#else\n"
+ ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
+ ret += "#endif\n"
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+ id2typ[msg.msgid] = msg
+
+ def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
+ ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
+ for ver in ["unknown", *sorted(versions)]:
+ if ver != "unknown":
+ ret += ifdef_push(1, c_ver_ifdef({ver}))
+ ret += f"\t[{c_ver_enum(ver)}] = {{\n"
+ for n in range(*rng):
+ xmsg: idl.Message | None = id2typ.get(n, None)
+ if xmsg:
+ if ver == "unknown": # SPECIAL
+ if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
+ xmsg = None
+ else:
+ if ver not in xmsg.in_versions:
+ xmsg = None
+ if xmsg:
+ ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
+ ret += "\t},\n"
+ ret += ifdef_pop(0)
+ ret += "};\n"
+ return ret
+
ret += "\n"
ret += "/**\n"
ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n"
@@ -338,18 +404,12 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
ret += " * several version checks together.\n"
ret += " */\n"
ret += "#define is_ver(ctx, ver) _is_ver_##ver(ctx->ctx->version)\n"
- for v in sorted(versions):
- ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
- ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
- ret += "#else\n"
- ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
- ret += "#endif\n"
# strings ##################################################################
ret += f"""
/* strings ********************************************************************/
-static const char *version_strs[{c_ver_enum('NUM')}] = {{
+const char *_lib9p_table_ver_name[{c_ver_enum('NUM')}] = {{
"""
for ver in ["unknown", *sorted(versions)]:
if ver in versions:
@@ -357,15 +417,10 @@ static const char *version_strs[{c_ver_enum('NUM')}] = {{
ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
ret += ifdef_pop(0)
ret += "};\n"
- ret += f"""
-const char *{idprefix}version_str(enum {idprefix}version ver) {{
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wtype-limits"
- assert(0 <= ver && ver < {c_ver_enum('NUM')});
-#pragma GCC diagnostic pop
- return version_strs[ver];
-}}
-"""
+
+ ret += "\n"
+ ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
+ ret += msg_table("msg", "name", "char *", (0, 0x100, 1))
# bitmasks #################################################################
ret += f"""
@@ -443,7 +498,7 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += "\tuint32_t base_offset = ctx->net_offset;\n"
ret += "\tif (validate_4(ctx))\n"
ret += "\t\treturn true;\n"
- ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
+ ret += "\tuint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]);\n"
ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
ret += "}\n"
continue
@@ -453,8 +508,8 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += "\tuint32_t base_offset = ctx->net_offset;\n"
ret += "\tif (validate_2(ctx))\n"
ret += "\t\treturn true;\n"
- ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
- ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
+ ret += "\tuint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]);\n"
+ ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)))\n"
ret += "\t\treturn true;\n"
ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
@@ -471,7 +526,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += (
f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
)
- ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ if typ.static_size == 1:
+ ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
+ else:
+ ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
ret += f"\tif (val & ~mask)\n"
ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
@@ -507,7 +565,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += "( " + c_ver_cond(member.in_versions) + " && "
if member.cnt is not None:
assert prev_size
- ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ if prev_size == 1:
+ ret += f"_validate_list(ctx, ctx->net_bytes[ctx->net_offset-1], validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ else:
+ ret += f"_validate_list(ctx, uint{prev_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
else:
if member.max or member.val:
ret += "("
@@ -517,10 +578,12 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
if member.name in mark_offset:
ret += "; })"
if member.max or member.val:
- bytes = member.static_size
- assert bytes
- bits = bytes * 8
- ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
+ nbytes = member.static_size
+ assert nbytes
+ if nbytes == 1:
+ ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
+ else:
+ ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
if member.in_versions != typ.in_versions:
ret += " )"
ret += "\n"
@@ -551,22 +614,22 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
/* unmarshal_* ****************************************************************/
LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
-\t*out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
+\t*out = ctx->net_bytes[ctx->net_offset];
\tctx->net_offset += 1;
}
LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
-\t*out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
+\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 2;
}
LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
-\t*out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
+\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 4;
}
LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
-\t*out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
+\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 8;
}
"""
@@ -600,9 +663,10 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += "{\n"
ret += prefix
ret += f"out->{member.name} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n"
if typ.name in ["d", "s"]: # SPECIAL
+ # Special-case is that we cast from `char` to `uint8_t`.
ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
else:
ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
@@ -612,9 +676,6 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += (
f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
)
- if typ.name == "s": # SPECIAL
- ret += "\tctx->extra++;\n"
- ret += "\tout->utf8[out->len] = '\\0';\n"
ret += ifdef_pop(1)
ret += "}\n"
ret += ifdef_pop(0)
@@ -642,7 +703,7 @@ LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
\tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
\t\treturn _marshal_too_large(ctx);
-\tencode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
+\tuint16le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 2;
\treturn false;
}
@@ -650,7 +711,7 @@ LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val)
LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
\tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
\t\treturn true;
-\tencode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
+\tuint32le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 4;
\treturn false;
}
@@ -658,7 +719,7 @@ LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val)
LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
\tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
\t\treturn true;
-\tencode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
+\tuint64le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 8;
\treturn false;
}
@@ -705,7 +766,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
ret += f"({{ _{member.name}_offset = ctx->net_offset; "
if member.cnt:
ret += "({ bool err = false;\n"
- ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n"
+ ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n"
ret += "\t \terr = "
if typ.name in ["d", "s"]: # SPECIAL
# Special-case is that we cast from `char` to `uint8_t`.
@@ -732,22 +793,21 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
if member.val:
assert member.static_size
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n"
+ if member.static_size == 1:
+ ret += f"\t || ({{ ctx->net_bytes[_{member.name}_offset] = {c_expr(member.val)}; false; }})\n"
+ else:
+ ret += f"\t || ({{ uint{member.static_size*8}le_encode(&ctx->net_bytes[_{member.name}_offset], {c_expr(member.val)}); false; }})\n"
ret += ifdef_pop(1)
ret += "\t ;\n"
ret += "}\n"
ret += ifdef_pop(0)
- # tables / exports #########################################################
+ # function tables ##########################################################
ret += """
-/* tables / exports ***********************************************************/
-"""
- id2typ: dict[int, idl.Message] = {}
- for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
- id2typ[msg.msgid] = msg
+/* function tables ************************************************************/
- ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
+"""
ret += c_macro(
f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{",
f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),",
@@ -760,35 +820,14 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,",
f"\t}}",
)
-
- tables = [
- ("msg", "name", "char *", (0, 0x100, 1)),
- ("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)),
- ("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)),
- ("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)),
- ("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)),
- ]
- for grp, meth, tentry, rng in tables:
- ret += "\n"
- ret += f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
- for ver in ["unknown", *sorted(versions)]:
- if ver != "unknown":
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {{\n"
- for n in range(*rng):
- xmsg: idl.Message | None = id2typ.get(n, None)
- if xmsg:
- if ver == "unknown": # SPECIAL
- if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
- xmsg = None
- else:
- if ver not in xmsg.in_versions:
- xmsg = None
- if xmsg:
- ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
- ret += "\t},\n"
- ret += ifdef_pop(0)
- ret += "};\n"
+ ret += "\n"
+ ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2))
ret += f"""
LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{