summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen74
1 files changed, 26 insertions, 48 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index 558592f..8782771 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -66,12 +66,10 @@ def c_ver_cond(versions: set[str]) -> str:
return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-def c_typename(typ: idl.Type, parent: idl.Type | None = None) -> str:
+def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
match typ:
case idl.Primitive():
- if (
- typ.value == 1 and parent and parent.name in ["d", "d_signed", "s"]
- ): # SPECIAL
+ if typ.value == 1 and parent and parent.cnt: # SPECIAL (string)
return "[[gnu::nonstring]] char"
return f"uint{typ.value*8}_t"
case idl.Number():
@@ -322,13 +320,13 @@ enum {idprefix}version {{
continue
ret += "\n"
- typewidth = max(len(c_typename(m.typ, typ)) for m in typ.members)
+ typewidth = max(len(c_typename(m.typ, m)) for m in typ.members)
for member in typ.members:
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ, typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
ret += ifdef_pop(1)
ret += "};\n"
ret += ifdef_pop(0)
@@ -381,7 +379,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
for n in range(*rng):
xmsg: idl.Message | None = id2typ.get(n, None)
if xmsg:
- if ver == "unknown": # SPECIAL
+ if ver == "unknown": # SPECIAL (initialization)
if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
xmsg = None
else:
@@ -484,10 +482,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
\treturn false;
}
-#define validate_1(ctx) _validate_size_net(ctx, 1)
-#define validate_2(ctx) _validate_size_net(ctx, 2)
-#define validate_4(ctx) _validate_size_net(ctx, 4)
-#define validate_8(ctx) _validate_size_net(ctx, 8)
+LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
+LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
+LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
+LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
for typ in topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
@@ -496,31 +494,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
- if typ.name == "d" or typ.name == "d_signed": # SPECIAL
- # Optimize... maybe the compiler could figure out to do
- # this, but let's make it obvious.
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_4(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]);\n"
- ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
- ret += "}\n"
- continue
- if typ.name == "s": # SPECIAL
- # Add an extra nul-byte on the host, and validate UTF-8
- # (also, similar optimization to "d").
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_2(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]);\n"
- ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)))\n"
- ret += "\t\treturn true;\n"
- ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
- ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
-
match typ:
case idl.Number():
ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
@@ -543,9 +516,18 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += "}\n"
continue
+ def should_save_value(member: idl.StructMember) -> bool:
+ nonlocal typ
+ assert isinstance(typ, idl.Struct)
+ return bool(
+ member.max
+ or member.val
+ or any(m.cnt == member for m in typ.members)
+ )
+
# Pass 1 - declare value variables
for member in typ.members:
- if member.max or member.val:
+ if should_save_value(member):
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
ret += f"\t{c_typename(member.typ)} {member.name};\n"
ret += ifdef_pop(1)
@@ -561,27 +543,24 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
# Pass 3 - main pass
ret += "\treturn false\n"
- prev_size: int | None = None
for member in typ.members:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
ret += f"\t || "
if member.in_versions != typ.in_versions:
ret += "( " + c_ver_cond(member.in_versions) + " && "
if member.cnt is not None:
- assert prev_size
- if prev_size == 1:
- ret += f"_validate_list(ctx, ctx->net_bytes[ctx->net_offset-1], validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
- else:
- ret += f"_validate_list(ctx, uint{prev_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ if typ.name == "s": # SPECIAL (string)
+ ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})'
else:
- if member.max or member.val:
+ if should_save_value(member):
ret += "("
if member.name in mark_offset:
ret += f"({{ _{member.name}_offset = ctx->net_offset; "
ret += f"validate_{member.typ.name}(ctx)"
if member.name in mark_offset:
ret += "; })"
- if member.max or member.val:
+ if should_save_value(member):
nbytes = member.static_size
assert nbytes
if nbytes == 1:
@@ -591,7 +570,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
if member.in_versions != typ.in_versions:
ret += " )"
ret += "\n"
- prev_size = member.static_size
# Pass 4 - validate ,max= and ,val= constraints
for member in typ.members:
@@ -669,7 +647,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += f"out->{member.name} = ctx->extra;\n"
ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n"
ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n"
- if typ.name in ["d", "d_signed", "s"]: # SPECIAL
+ if member.typ.static_size == 1: # SPECIAL (string)
# Special-case is that we cast from `char` to `uint8_t`.
ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
else:
@@ -772,7 +750,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
ret += "({ bool err = false;\n"
ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n"
ret += "\t \terr = "
- if typ.name in ["d", "d_signed", "s"]: # SPECIAL
+ if member.typ.static_size == 1: # SPECIAL (string)
# Special-case is that we cast from `char` to `uint8_t`.
ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
else: