summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen272
1 files changed, 147 insertions, 125 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index db5f37d..779b6d5 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -12,8 +12,7 @@ import sys
import typing
sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
-
-import idl
+import idl # pylint: disable=wrong-import-position,import-self
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
@@ -74,13 +73,13 @@ def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
return "[[gnu::nonstring]] char"
return f"uint{typ.value*8}_t"
case idl.Number():
- return f"{idprefix}{typ.name}_t"
+ return f"{idprefix}{typ.typname}_t"
case idl.Bitfield():
- return f"{idprefix}{typ.name}_t"
+ return f"{idprefix}{typ.typname}_t"
case idl.Message():
- return f"struct {idprefix}msg_{typ.name}"
+ return f"struct {idprefix}msg_{typ.typname}"
case idl.Struct():
- return f"struct {idprefix}{typ.name}"
+ return f"struct {idprefix}{typ.typname}"
case _:
raise ValueError(f"not a type: {typ.__class__.__name__}")
@@ -93,12 +92,12 @@ def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
ret.append(tok.op)
case idl.ExprLit():
ret.append(str(tok.val))
- case idl.ExprSym(name="s32_max"):
+ case idl.ExprSym(symname="s32_max"):
ret.append("INT32_MAX")
- case idl.ExprSym(name="s64_max"):
+ case idl.ExprSym(symname="s64_max"):
ret.append("INT64_MAX")
case idl.ExprSym():
- ret.append(lookup_sym(tok.name))
+ ret.append(lookup_sym(tok.symname))
case _:
assert False
return " ".join(ret)
@@ -109,7 +108,6 @@ _ifdef_stack: list[str | None] = []
def ifdef_push(n: int, _newval: str) -> str:
# Grow the stack as needed
- global _ifdef_stack
while len(_ifdef_stack) < n:
_ifdef_stack.append(None)
@@ -191,14 +189,14 @@ class Path:
for i, elem in enumerate(self.elems):
if i > 0:
ret += "."
- ret += elem.name
+ ret += elem.membname
if elem.cnt:
ret += f"[{chr(ord('i')+loopdepth)}]"
loopdepth += 1
return ret
def __str__(self) -> str:
- return self.c_str(self.root.name + "->")
+ return self.c_str(self.root.typname + "->")
class WalkCmd(enum.Enum):
@@ -243,7 +241,7 @@ def walk(typ: idl.Type, handle: WalkHandler) -> None:
# get_buffer_size() ############################################################
-class BufferSize:
+class BufferSize(typing.NamedTuple):
min_size: int # really just here to sanity-check against typ.min_size(version)
exp_size: int # "expected" or max-reasonable size
max_size: int # really just here to sanity-check against typ.max_size(version)
@@ -251,8 +249,19 @@ class BufferSize:
max_copy_extra: str
max_iov: int
max_iov_extra: str
- _starts_with_copy: bool
- _ends_with_copy: bool
+
+
+class TmpBufferSize:
+ min_size: int
+ exp_size: int
+ max_size: int
+ max_copy: int
+ max_copy_extra: str
+ max_iov: int
+ max_iov_extra: str
+
+ tmp_starts_with_copy: bool
+ tmp_ends_with_copy: bool
def __init__(self) -> None:
self.min_size = 0
@@ -262,14 +271,14 @@ class BufferSize:
self.max_copy_extra = ""
self.max_iov = 0
self.max_iov_extra = ""
- self._starts_with_copy = False
- self._ends_with_copy = False
+ self.tmp_starts_with_copy = False
+ self.tmp_ends_with_copy = False
-def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
+def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)
- ret = BufferSize()
+ ret = TmpBufferSize()
if not isinstance(typ, idl.Struct):
assert typ.static_size
@@ -278,8 +287,8 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
ret.max_size = typ.static_size
ret.max_copy = typ.static_size
ret.max_iov = 1
- ret._starts_with_copy = True
- ret._ends_with_copy = True
+ ret.tmp_starts_with_copy = True
+ ret.tmp_ends_with_copy = True
return ret
def handle(path: Path) -> tuple[WalkCmd, None]:
@@ -292,20 +301,20 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
if child.typ.static_size == 1: # SPECIAL (zerocopy)
ret.max_iov += 1
# HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
- ret.exp_size += 27 if child.name == "utf8" else 8192
+ ret.exp_size += 27 if child.membname == "utf8" else 8192
ret.max_size += child.max_cnt
- ret._ends_with_copy = False
+ ret.tmp_ends_with_copy = False
return WalkCmd.DONT_RECURSE, None
- sub = get_buffer_size(child.typ, version)
+ sub = _get_buffer_size(child.typ, version)
ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM
ret.max_size += sub.max_size * child.max_cnt
- if child.name == "wname" and path.root.name in (
+ if child.membname == "wname" and path.root.typname in (
"Tsread",
"Tswrite",
): # SPECIAL (9P2000.e)
- assert ret._ends_with_copy
- assert sub._starts_with_copy
- assert not sub._ends_with_copy
+ assert ret.tmp_ends_with_copy
+ assert sub.tmp_starts_with_copy
+ assert not sub.tmp_ends_with_copy
ret.max_copy_extra = (
f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
)
@@ -315,29 +324,29 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
ret.max_iov -= 1
else:
ret.max_copy += sub.max_copy * child.max_cnt
- if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy
+ if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy
ret.max_iov += 1
else: # contains zero-copy segments
ret.max_iov += sub.max_iov * child.max_cnt
- if ret._ends_with_copy and sub._starts_with_copy:
+ if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
# we can merge this one
ret.max_iov -= 1
if (
- sub._ends_with_copy
- and sub._starts_with_copy
+ sub.tmp_ends_with_copy
+ and sub.tmp_starts_with_copy
and sub.max_iov > 1
):
# we can merge these
ret.max_iov -= child.max_cnt - 1
- ret._ends_with_copy = sub._ends_with_copy
+ ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
return WalkCmd.DONT_RECURSE, None
- elif not isinstance(child.typ, idl.Struct):
+ if not isinstance(child.typ, idl.Struct):
assert child.typ.static_size
- if not ret._ends_with_copy:
+ if not ret.tmp_ends_with_copy:
if ret.max_size == 0:
- ret._starts_with_copy = True
+ ret.tmp_starts_with_copy = True
ret.max_iov += 1
- ret._ends_with_copy = True
+ ret.tmp_ends_with_copy = True
ret.min_size += child.typ.static_size
ret.exp_size += child.typ.static_size
ret.max_size += child.typ.static_size
@@ -350,6 +359,19 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
return ret
+def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
+ tmp = _get_buffer_size(typ, version)
+ return BufferSize(
+ min_size=tmp.min_size,
+ exp_size=tmp.exp_size,
+ max_size=tmp.max_size,
+ max_copy=tmp.max_copy,
+ max_copy_extra=tmp.max_copy_extra,
+ max_iov=tmp.max_iov,
+ max_iov_extra=tmp.max_iov_extra,
+ )
+
+
# Generate .h ##################################################################
@@ -372,7 +394,7 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
id2typ[msg.msgid] = msg
- ret += f"""
+ ret += """
/* config *********************************************************************/
#include "config.h"
@@ -384,7 +406,7 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
if ver == "9P2000.e": # SPECIAL (9P2000.e)
ret += "#else\n"
ret += f"\t#if {c_ver_ifdef({ver})}\n"
- ret += "\t\t#ifndef(CONFIG_9P_MAX_9P2000_e_WELEM)\n"
+ ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
ret += "\t\t#endif\n"
ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
@@ -412,13 +434,15 @@ enum {idprefix}version {{
"""
ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
- namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message))
+ namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
for n in range(0x100):
if n not in id2typ:
continue
msg = id2typ[n]
ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
+ ret += (
+ f"\t{idprefix.upper()}TYP_{msg.typname.ljust(namewidth)} = {msg.msgid},\n"
+ )
ret += ifdef_pop(0)
ret += "};\n"
@@ -469,7 +493,7 @@ enum {idprefix}version {{
match typ:
case idl.Number():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
- prefix = f"{idprefix.upper()}{typ.name.upper()}_"
+ prefix = f"{idprefix.upper()}{typ.typname.upper()}_"
namewidth = max(len(name) for name in typ.vals)
for name, val in typ.vals.items():
ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
@@ -481,7 +505,7 @@ enum {idprefix}version {{
if aliases := [k for k in typ.names if k not in typ.bits]:
names.append("")
names.extend(aliases)
- prefix = f"{idprefix.upper()}{typ.name.upper()}_"
+ prefix = f"{idprefix.upper()}{typ.typname.upper()}_"
namewidth = max(len(add_prefix(prefix, name)) for name in names)
ret += "\n"
@@ -527,7 +551,7 @@ enum {idprefix}version {{
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.membname};\n"
ret += ifdef_pop(1)
ret += "};\n"
ret += ifdef_pop(0)
@@ -545,7 +569,7 @@ enum {idprefix}version {{
for typ in typs:
if not isinstance(typ, idl.Message):
continue
- if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e)
+ if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e)
continue
max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
@@ -578,7 +602,7 @@ enum {idprefix}version {{
ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n"
indent = 1
if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
- typ = next(typ for typ in typs if typ.name == "Tswrite")
+ typ = next(typ for typ in typs if typ.typname == "Tswrite")
sz = get_buffer_size(typ, "9P2000.e")
match name:
case "tmsg_max_iov":
@@ -589,7 +613,7 @@ enum {idprefix}version {{
assert False
ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n"
ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n"
- ret += f"\t#else\n"
+ ret += "\t#else\n"
indent += 1
ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n"
if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
@@ -601,14 +625,14 @@ enum {idprefix}version {{
ret += "\n"
ret += f"struct {idprefix}Tmsg_send_buf {{\n"
- ret += f"\tsize_t iov_cnt;\n"
+ ret += "\tsize_t iov_cnt;\n"
ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n"
ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n"
ret += "};\n"
ret += "\n"
ret += f"struct {idprefix}Rmsg_send_buf {{\n"
- ret += f"\tsize_t iov_cnt;\n"
+ ret += "\tsize_t iov_cnt;\n"
ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n"
ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n"
ret += "};\n"
@@ -638,7 +662,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
"""
# utilities ################################################################
- ret += f"""
+ ret += """
/* utilities ******************************************************************/
"""
@@ -662,13 +686,13 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
xmsg: idl.Message | None = id2typ.get(n, None)
if xmsg:
if ver == "unknown": # SPECIAL (initialization)
- if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
+ if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]:
xmsg = None
else:
if ver not in xmsg.in_versions:
xmsg = None
if xmsg:
- ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
+ ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n"
ret += "\t},\n"
ret += ifdef_pop(0)
ret += "};\n"
@@ -707,7 +731,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
ret += msg_table("msg", "name", "char *const", (0, 0x100, 1))
# bitmasks #################################################################
- ret += f"""
+ ret += """
/* bitmasks *******************************************************************/
"""
for typ in typs:
@@ -715,7 +739,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
continue
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
+ ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n"
verwidth = max(len(ver) for ver in versions)
for ver in sorted(versions):
ret += ifdef_push(2, c_ver_ifdef({ver}))
@@ -769,28 +793,32 @@ LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _val
LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
+
+ def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
+ return bool(
+ member.max or member.val or any(m.cnt == member for m in typ.members)
+ )
+
for typ in topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
+ ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
match typ:
case idl.Number():
- ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
+ ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
case idl.Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
- ret += (
- f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
- )
+ ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
if typ.static_size == 1:
ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
else:
ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
- ret += f"\tif (val & ~mask)\n"
- ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
+ ret += "\tif (val & ~mask)\n"
+ ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
case idl.Struct(): # and idl.Message()
if len(typ.members) == 0:
@@ -798,60 +826,51 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
ret += "}\n"
continue
- def should_save_value(member: idl.StructMember) -> bool:
- nonlocal typ
- assert isinstance(typ, idl.Struct)
- return bool(
- member.max
- or member.val
- or any(m.cnt == member for m in typ.members)
- )
-
# Pass 1 - declare value variables
for member in typ.members:
- if should_save_value(member):
+ if should_save_value(typ, member):
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ)} {member.name};\n"
+ ret += f"\t{c_typename(member.typ)} {member.membname};\n"
ret += ifdef_pop(1)
# Pass 2 - declare offset variables
mark_offset: set[str] = set()
for member in typ.members:
for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
- if tok.name[1:] not in mark_offset:
- ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
- mark_offset.add(tok.name[1:])
+ if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
+ if tok.symname[1:] not in mark_offset:
+ ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
+ mark_offset.add(tok.symname[1:])
# Pass 3 - main pass
ret += "\treturn false\n"
for member in typ.members:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || "
+ ret += "\t || "
if member.in_versions != typ.in_versions:
ret += "( " + c_ver_cond(member.in_versions) + " && "
if member.cnt is not None:
if member.typ.static_size == 1: # SPECIAL (zerocopy)
- ret += f"_validate_size_net(ctx, {member.cnt.name})"
+ ret += f"_validate_size_net(ctx, {member.cnt.membname})"
else:
- ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
- if typ.name == "s": # SPECIAL (string)
- ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})'
+ ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))"
+ if typ.typname == "s": # SPECIAL (string)
+ ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
else:
- if should_save_value(member):
+ if should_save_value(typ, member):
ret += "("
- if member.name in mark_offset:
- ret += f"({{ _{member.name}_offset = ctx->net_offset; "
- ret += f"validate_{member.typ.name}(ctx)"
- if member.name in mark_offset:
+ if member.membname in mark_offset:
+ ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
+ ret += f"validate_{member.typ.typname}(ctx)"
+ if member.membname in mark_offset:
ret += "; })"
- if should_save_value(member):
+ if should_save_value(typ, member):
nbytes = member.static_size
assert nbytes
if nbytes == 1:
- ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
+ ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
else:
- ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
+ ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
if member.in_versions != typ.in_versions:
ret += " )"
ret += "\n"
@@ -871,14 +890,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val
assert member.static_size
nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.name}) > max) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
if member.val:
assert member.static_size
nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.name}) != exp) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
ret += ifdef_pop(1)
ret += "\t ;\n"
@@ -914,12 +933,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
+ ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
match typ:
case idl.Number():
- ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
case idl.Bitfield():
- ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
case idl.Struct():
ret += "\tmemset(out, 0, sizeof(*out));\n"
@@ -939,21 +958,17 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += "{\n"
ret += prefix
if member.typ.static_size == 1: # SPECIAL (string, zerocopy)
- ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
- ret += (
- f"{prefix}ctx->net_offset += out->{member.cnt.name};\n"
- )
+ ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
+ ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n"
else:
- ret += f"out->{member.name} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n"
- ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
+ ret += f"out->{member.membname} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n"
+ ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n"
if member.in_versions != typ.in_versions:
ret += "\t}\n"
else:
- ret += (
- f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
- )
+ ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n"
ret += ifdef_pop(1)
ret += "}\n"
ret += ifdef_pop(0)
@@ -1140,7 +1155,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]:
def ret(path: Path) -> WalkCmd:
- if len(path.elems) == 1 and path.elems[0].name == name:
+ if len(path.elems) == 1 and path.elems[0].membname == name:
return WalkCmd.ABORT
return WalkCmd.KEEP_GOING
@@ -1148,13 +1163,13 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
for typ in typs:
if not (
- isinstance(typ, idl.Message) or typ.name == "stat"
+ isinstance(typ, idl.Message) or typ.typname == "stat"
): # SPECIAL (include stat)
continue
assert isinstance(typ, idl.Struct)
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n"
+ ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n"
# Pass 1 - check size
max_size = max(typ.max_size(v) for v in typ.in_versions)
@@ -1170,8 +1185,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
)
ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
- ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%PRIu32)",\n'
- ret += f'\t\t\t"{typ.name}",\n'
+ ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
+ ret += f'\t\t\t"{typ.typname}",\n'
ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
ret += "\t\t\tctx->ctx->max_msg_size);\n"
ret += "\t\treturn true;\n"
@@ -1209,12 +1224,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
for tok in member.val.tokens:
if not isinstance(tok, idl.ExprSym):
continue
- if tok.name == "end" or tok.name.startswith("&"):
- if tok.name not in offsets:
- offsets.append(tok.name)
+ if tok.symname == "end" or tok.symname.startswith("&"):
+ if tok.symname not in offsets:
+ offsets.append(tok.symname)
for name in offsets:
name_prefix = "offsetof_" + "".join(
- m.name + "_" for m in path.elems
+ m.membname + "_" for m in path.elems
)
if name == "end":
if not path.elems:
@@ -1251,7 +1266,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
if child.cnt:
cnt_path = path.parent().add(child.cnt)
if child.typ.static_size == 1: # SPECIAL (zerocopy)
- if path.root.name == "stat": # SPECIAL (stat)
+ if path.root.typname == "stat": # SPECIAL (stat)
ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
else:
ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
@@ -1268,7 +1283,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
sym = sym[1:]
return (
"offsetof_"
- + "".join(m.name + "_" for m in path.elems[:-1])
+ + "".join(m.membname + "_" for m in path.elems[:-1])
+ sym
)
@@ -1276,11 +1291,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
else:
val = path.c_str("val->")
if isinstance(child.typ, idl.Bitfield):
- val += f" & {child.typ.name}_masks[ctx->ctx->version]"
+ val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
return WalkCmd.KEEP_GOING, pop
walk(typ, handle)
+ del handle
+ del stack
+ del max_size
ret += "\treturn false;\n"
ret += "}\n"
@@ -1293,7 +1311,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
ret += "\n"
ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n"
- rerror = next(typ for typ in typs if typ.name == "Rerror")
+ rerror = next(typ for typ in typs if typ.typname == "Rerror")
ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
for ver in sorted(versions):
ret += ifdef_push(1, c_ver_ifdef({ver}))
@@ -1342,9 +1360,7 @@ LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idpref
# Main #########################################################################
-if __name__ == "__main__":
- import sys
-
+def main() -> None:
if typing.TYPE_CHECKING:
class ANSIColors:
@@ -1375,7 +1391,13 @@ if __name__ == "__main__":
sys.exit(2)
versions, typs = parser.all()
outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
- with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
+ with open(
+ os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8"
+ ) as fh:
fh.write(gen_h(versions, typs))
- with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
+ with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh:
fh.write(gen_c(versions, typs))
+
+
+if __name__ == "__main__":
+ main()