summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen48
1 files changed, 40 insertions, 8 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index e796855..31f6527 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -5,8 +5,10 @@
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
+import graphlib
import os.path
import sys
+import typing
sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
@@ -90,6 +92,10 @@ def c_expr(expr: idl.Expr) -> str:
ret += [str(tok.val)]
case idl.ExprSym(name="end"):
ret += ["ctx->net_offset"]
+ case idl.ExprSym(name="s32_max"):
+ ret += ["INT32_MAX"]
+ case idl.ExprSym(name="s64_max"):
+ ret += ["INT64_MAX"]
case idl.ExprSym():
ret += [f"_{tok.name[1:]}_offset"]
return " ".join(ret)
@@ -137,6 +143,24 @@ def ifdef_pop(n: int) -> str:
return ret
+def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]:
+ ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter()
+ for typ in typs:
+ match typ:
+ case idl.Number():
+ ts.add(typ)
+ case idl.Bitfield():
+ ts.add(typ)
+ case idl.Struct(): # and idl.Message():
+ deps = [
+ member.typ
+ for member in typ.members
+ if not isinstance(member.typ, idl.Primitive)
+ ]
+ ts.add(typ, *deps)
+ return ts.static_order()
+
+
# Generate .h ##################################################################
@@ -197,12 +221,16 @@ enum {idprefix}version {{
ret += """
/* payload types **************************************************************/
"""
- for typ in typs:
+ for typ in topo_sorted(typs):
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
match typ:
case idl.Number():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
+ prefix = f"{idprefix.upper()}{typ.name.upper()}_"
+ namewidth = max(len(name) for name in typ.vals)
+ for name, val in typ.vals.items():
+ ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
case idl.Bitfield():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
names = [
@@ -381,7 +409,7 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
#define validate_4(ctx) _validate_size_net(ctx, 4)
#define validate_8(ctx) _validate_size_net(ctx, 8)
"""
- for typ in typs:
+ for typ in topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
@@ -480,13 +508,17 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
# Pass 4 - validate ,max= and ,val= constraints
for member in typ.members:
if member.max:
+ assert member.static_size
+ nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max)}; (((uint{nbits}_t){member.name}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n'
if member.val:
+ assert member.static_size
+ nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val)}; (((uint{nbits}_t){member.name}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n'
ret += ifdef_pop(1)
ret += "\t ;\n"
@@ -517,7 +549,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
\tctx->net_offset += 8;
}
"""
- for typ in typs:
+ for typ in topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
@@ -610,7 +642,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
\treturn false;
}
"""
- for typ in typs:
+ for typ in topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"