diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 272 |
1 files changed, 147 insertions, 125 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index db5f37d..779b6d5 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -12,8 +12,7 @@ import sys import typing sys.path.insert(0, os.path.normpath(os.path.join(__file__, ".."))) - -import idl +import idl # pylint: disable=wrong-import-position,import-self # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -74,13 +73,13 @@ def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): - return f"{idprefix}{typ.name}_t" + return f"{idprefix}{typ.typname}_t" case idl.Bitfield(): - return f"{idprefix}{typ.name}_t" + return f"{idprefix}{typ.typname}_t" case idl.Message(): - return f"struct {idprefix}msg_{typ.name}" + return f"struct {idprefix}msg_{typ.typname}" case idl.Struct(): - return f"struct {idprefix}{typ.name}" + return f"struct {idprefix}{typ.typname}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") @@ -93,12 +92,12 @@ def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: ret.append(tok.op) case idl.ExprLit(): ret.append(str(tok.val)) - case idl.ExprSym(name="s32_max"): + case idl.ExprSym(symname="s32_max"): ret.append("INT32_MAX") - case idl.ExprSym(name="s64_max"): + case idl.ExprSym(symname="s64_max"): ret.append("INT64_MAX") case idl.ExprSym(): - ret.append(lookup_sym(tok.name)) + ret.append(lookup_sym(tok.symname)) case _: assert False return " ".join(ret) @@ -109,7 +108,6 @@ _ifdef_stack: list[str | None] = [] def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed - global _ifdef_stack while len(_ifdef_stack) < n: _ifdef_stack.append(None) @@ -191,14 +189,14 @@ class Path: for i, elem in enumerate(self.elems): if i > 0: ret += "." - ret += elem.name + ret += elem.membname if elem.cnt: ret += f"[{chr(ord('i')+loopdepth)}]" loopdepth += 1 return ret def __str__(self) -> str: - return self.c_str(self.root.name + "->") + return self.c_str(self.root.typname + "->") class WalkCmd(enum.Enum): @@ -243,7 +241,7 @@ def walk(typ: idl.Type, handle: WalkHandler) -> None: # get_buffer_size() ############################################################ -class BufferSize: +class BufferSize(typing.NamedTuple): min_size: int # really just here to sanity-check against typ.min_size(version) exp_size: int # "expected" or max-reasonable size max_size: int # really just here to sanity-check against typ.max_size(version) @@ -251,8 +249,19 @@ class BufferSize: max_copy_extra: str max_iov: int max_iov_extra: str - _starts_with_copy: bool - _ends_with_copy: bool + + +class TmpBufferSize: + min_size: int + exp_size: int + max_size: int + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + tmp_starts_with_copy: bool + tmp_ends_with_copy: bool def __init__(self) -> None: self.min_size = 0 @@ -262,14 +271,14 @@ class BufferSize: self.max_copy_extra = "" self.max_iov = 0 self.max_iov_extra = "" - self._starts_with_copy = False - self._ends_with_copy = False + self.tmp_starts_with_copy = False + self.tmp_ends_with_copy = False -def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: +def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) - ret = BufferSize() + ret = TmpBufferSize() if not isinstance(typ, idl.Struct): assert typ.static_size @@ -278,8 +287,8 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: ret.max_size = typ.static_size ret.max_copy = typ.static_size ret.max_iov = 1 - ret._starts_with_copy = True - ret._ends_with_copy = True + ret.tmp_starts_with_copy = True + ret.tmp_ends_with_copy = True return ret def handle(path: Path) -> tuple[WalkCmd, None]: @@ -292,20 +301,20 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: if child.typ.static_size == 1: # SPECIAL (zerocopy) ret.max_iov += 1 # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data - ret.exp_size += 27 if child.name == "utf8" else 8192 + ret.exp_size += 27 if child.membname == "utf8" else 8192 ret.max_size += child.max_cnt - ret._ends_with_copy = False + ret.tmp_ends_with_copy = False return WalkCmd.DONT_RECURSE, None - sub = get_buffer_size(child.typ, version) + sub = _get_buffer_size(child.typ, version) ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM ret.max_size += sub.max_size * child.max_cnt - if child.name == "wname" and path.root.name in ( + if child.membname == "wname" and path.root.typname in ( "Tsread", "Tswrite", ): # SPECIAL (9P2000.e) - assert ret._ends_with_copy - assert sub._starts_with_copy - assert not sub._ends_with_copy + assert ret.tmp_ends_with_copy + assert sub.tmp_starts_with_copy + assert not sub.tmp_ends_with_copy ret.max_copy_extra = ( f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" ) @@ -315,29 +324,29 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: ret.max_iov -= 1 else: ret.max_copy += sub.max_copy * child.max_cnt - if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy + if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy ret.max_iov += 1 else: # contains zero-copy segments ret.max_iov += sub.max_iov * child.max_cnt - if ret._ends_with_copy and sub._starts_with_copy: + if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: # we can merge this one ret.max_iov -= 1 if ( - sub._ends_with_copy - and sub._starts_with_copy + sub.tmp_ends_with_copy + and sub.tmp_starts_with_copy and sub.max_iov > 1 ): # we can merge these ret.max_iov -= child.max_cnt - 1 - ret._ends_with_copy = sub._ends_with_copy + ret.tmp_ends_with_copy = sub.tmp_ends_with_copy return WalkCmd.DONT_RECURSE, None - elif not isinstance(child.typ, idl.Struct): + if not isinstance(child.typ, idl.Struct): assert child.typ.static_size - if not ret._ends_with_copy: + if not ret.tmp_ends_with_copy: if ret.max_size == 0: - ret._starts_with_copy = True + ret.tmp_starts_with_copy = True ret.max_iov += 1 - ret._ends_with_copy = True + ret.tmp_ends_with_copy = True ret.min_size += child.typ.static_size ret.exp_size += child.typ.static_size ret.max_size += child.typ.static_size @@ -350,6 +359,19 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: return ret +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + tmp = _get_buffer_size(typ, version) + return BufferSize( + min_size=tmp.min_size, + exp_size=tmp.exp_size, + max_size=tmp.max_size, + max_copy=tmp.max_copy, + max_copy_extra=tmp.max_copy_extra, + max_iov=tmp.max_iov, + max_iov_extra=tmp.max_iov_extra, + ) + + # Generate .h ################################################################## @@ -372,7 +394,7 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str: for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg - ret += f""" + ret += """ /* config *********************************************************************/ #include "config.h" @@ -384,7 +406,7 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str: if ver == "9P2000.e": # SPECIAL (9P2000.e) ret += "#else\n" ret += f"\t#if {c_ver_ifdef({ver})}\n" - ret += "\t\t#ifndef(CONFIG_9P_MAX_9P2000_e_WELEM)\n" + ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += "\t\t#endif\n" ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" @@ -412,13 +434,15 @@ enum {idprefix}version {{ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" - namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message)) + namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" + ret += ( + f"\t{idprefix.upper()}TYP_{msg.typname.ljust(namewidth)} = {msg.msgid},\n" + ) ret += ifdef_pop(0) ret += "};\n" @@ -469,7 +493,7 @@ enum {idprefix}version {{ match typ: case idl.Number(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" - prefix = f"{idprefix.upper()}{typ.name.upper()}_" + prefix = f"{idprefix.upper()}{typ.typname.upper()}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" @@ -481,7 +505,7 @@ enum {idprefix}version {{ if aliases := [k for k in typ.names if k not in typ.bits]: names.append("") names.extend(aliases) - prefix = f"{idprefix.upper()}{typ.name.upper()}_" + prefix = f"{idprefix.upper()}{typ.typname.upper()}_" namewidth = max(len(add_prefix(prefix, name)) for name in names) ret += "\n" @@ -527,7 +551,7 @@ enum {idprefix}version {{ if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.membname};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -545,7 +569,7 @@ enum {idprefix}version {{ for typ in typs: if not isinstance(typ, idl.Message): continue - if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) + if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) continue max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy @@ -578,7 +602,7 @@ enum {idprefix}version {{ ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) - typ = next(typ for typ in typs if typ.name == "Tswrite") + typ = next(typ for typ in typs if typ.typname == "Tswrite") sz = get_buffer_size(typ, "9P2000.e") match name: case "tmsg_max_iov": @@ -589,7 +613,7 @@ enum {idprefix}version {{ assert False ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" - ret += f"\t#else\n" + ret += "\t#else\n" indent += 1 ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) @@ -601,14 +625,14 @@ enum {idprefix}version {{ ret += "\n" ret += f"struct {idprefix}Tmsg_send_buf {{\n" - ret += f"\tsize_t iov_cnt;\n" + ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" ret += "};\n" ret += "\n" ret += f"struct {idprefix}Rmsg_send_buf {{\n" - ret += f"\tsize_t iov_cnt;\n" + ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" ret += "};\n" @@ -638,7 +662,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: """ # utilities ################################################################ - ret += f""" + ret += """ /* utilities ******************************************************************/ """ @@ -662,13 +686,13 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: if ver == "unknown": # SPECIAL (initialization) - if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: + if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: if ver not in xmsg.in_versions: xmsg = None if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" + ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" ret += "\t},\n" ret += ifdef_pop(0) ret += "};\n" @@ -707,7 +731,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) # bitmasks ################################################################# - ret += f""" + ret += """ /* bitmasks *******************************************************************/ """ for typ in typs: @@ -715,7 +739,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ continue ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" + ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += ifdef_push(2, c_ver_ifdef({ver})) @@ -769,28 +793,32 @@ LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _val LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ + + def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: + return bool( + member.max or member.val or any(m.cnt == member for m in typ.members) + ) + for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" + ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: case idl.Number(): - ret += f"\treturn validate_{typ.prim.name}(ctx);\n" + ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" - ret += ( - f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" - ) + ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += f"\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' + ret += "\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case idl.Struct(): # and idl.Message() if len(typ.members) == 0: @@ -798,60 +826,51 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val ret += "}\n" continue - def should_save_value(member: idl.StructMember) -> bool: - nonlocal typ - assert isinstance(typ, idl.Struct) - return bool( - member.max - or member.val - or any(m.cnt == member for m in typ.members) - ) - # Pass 1 - declare value variables for member in typ.members: - if should_save_value(member): + if should_save_value(typ, member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ)} {member.name};\n" + ret += f"\t{c_typename(member.typ)} {member.membname};\n" ret += ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): - if tok.name[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" - mark_offset.add(tok.name[1:]) + if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): + if tok.symname[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" + mark_offset.add(tok.symname[1:]) # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || " + ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) - ret += f"_validate_size_net(ctx, {member.cnt.name})" + ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: - ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" - if typ.name == "s": # SPECIAL (string) - ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))" + if typ.typname == "s": # SPECIAL (string) + ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: - if should_save_value(member): + if should_save_value(typ, member): ret += "(" - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.name}(ctx)" - if member.name in mark_offset: + if member.membname in mark_offset: + ret += f"({{ _{member.membname}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.typname}(ctx)" + if member.membname in mark_offset: ret += "; })" - if should_save_value(member): + if should_save_value(typ, member): nbytes = member.static_size assert nbytes if nbytes == 1: - ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" else: - ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" + ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" @@ -871,14 +890,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.name}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n' + ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.name}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n' + ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" @@ -914,12 +933,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" match typ: case idl.Number(): - ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" @@ -939,21 +958,17 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "{\n" ret += prefix if member.typ.static_size == 1: # SPECIAL (string, zerocopy) - ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n" - ret += ( - f"{prefix}ctx->net_offset += out->{member.cnt.name};\n" - ) + ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" else: - ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + ret += f"out->{member.membname} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" else: - ret += ( - f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" - ) + ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" ret += ifdef_pop(1) ret += "}\n" ret += ifdef_pop(0) @@ -1140,7 +1155,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: def ret(path: Path) -> WalkCmd: - if len(path.elems) == 1 and path.elems[0].name == name: + if len(path.elems) == 1 and path.elems[0].membname == name: return WalkCmd.ABORT return WalkCmd.KEEP_GOING @@ -1148,13 +1163,13 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o for typ in typs: if not ( - isinstance(typ, idl.Message) or typ.name == "stat" + isinstance(typ, idl.Message) or typ.typname == "stat" ): # SPECIAL (include stat) continue assert isinstance(typ, idl.Struct) ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) @@ -1170,8 +1185,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ) ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" if isinstance(typ, idl.Message): # SPECIAL (disable for stat) - ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%PRIu32)",\n' - ret += f'\t\t\t"{typ.name}",\n' + ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' + ret += f'\t\t\t"{typ.typname}",\n' ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' ret += "\t\t\tctx->ctx->max_msg_size);\n" ret += "\t\treturn true;\n" @@ -1209,12 +1224,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o for tok in member.val.tokens: if not isinstance(tok, idl.ExprSym): continue - if tok.name == "end" or tok.name.startswith("&"): - if tok.name not in offsets: - offsets.append(tok.name) + if tok.symname == "end" or tok.symname.startswith("&"): + if tok.symname not in offsets: + offsets.append(tok.symname) for name in offsets: name_prefix = "offsetof_" + "".join( - m.name + "_" for m in path.elems + m.membname + "_" for m in path.elems ) if name == "end": if not path.elems: @@ -1251,7 +1266,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o if child.cnt: cnt_path = path.parent().add(child.cnt) if child.typ.static_size == 1: # SPECIAL (zerocopy) - if path.root.name == "stat": # SPECIAL (stat) + if path.root.typname == "stat": # SPECIAL (stat) ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" @@ -1268,7 +1283,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o sym = sym[1:] return ( "offsetof_" - + "".join(m.name + "_" for m in path.elems[:-1]) + + "".join(m.membname + "_" for m in path.elems[:-1]) + sym ) @@ -1276,11 +1291,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): - val += f" & {child.typ.name}_masks[ctx->ctx->version]" + val += f" & {child.typ.typname}_masks[ctx->ctx->version]" ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" return WalkCmd.KEEP_GOING, pop walk(typ, handle) + del handle + del stack + del max_size ret += "\treturn false;\n" ret += "}\n" @@ -1293,7 +1311,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "\n" ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" - rerror = next(typ for typ in typs if typ.name == "Rerror") + rerror = next(typ for typ in typs if typ.typname == "Rerror") ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) for ver in sorted(versions): ret += ifdef_push(1, c_ver_ifdef({ver})) @@ -1342,9 +1360,7 @@ LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idpref # Main ######################################################################### -if __name__ == "__main__": - import sys - +def main() -> None: if typing.TYPE_CHECKING: class ANSIColors: @@ -1375,7 +1391,13 @@ if __name__ == "__main__": sys.exit(2) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) - with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: + with open( + os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8" + ) as fh: fh.write(gen_h(versions, typs)) - with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: + with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: fh.write(gen_c(versions, typs)) + + +if __name__ == "__main__": + main() |