diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 99 |
1 files changed, 50 insertions, 49 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index b23b5b8..8782771 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -66,12 +66,10 @@ def c_ver_cond(versions: set[str]) -> str: return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" -def c_typename(typ: idl.Type, parent: idl.Type | None = None) -> str: +def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: match typ: case idl.Primitive(): - if ( - typ.value == 1 and parent and parent.name in ["d", "d_signed", "s"] - ): # SPECIAL + if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): @@ -322,13 +320,13 @@ enum {idprefix}version {{ continue ret += "\n" - typewidth = max(len(c_typename(m.typ, typ)) for m in typ.members) + typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -381,7 +379,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: - if ver == "unknown": # SPECIAL + if ver == "unknown": # SPECIAL (initialization) if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: @@ -484,10 +482,10 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, \treturn false; } -#define validate_1(ctx) _validate_size_net(ctx, 1) -#define validate_2(ctx) _validate_size_net(ctx, 2) -#define validate_4(ctx) _validate_size_net(ctx, 4) -#define validate_8(ctx) _validate_size_net(ctx, 8) +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" @@ -496,31 +494,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" - if typ.name == "d" or typ.name == "d_signed": # SPECIAL - # Optimize... maybe the compiler could figure out to do - # this, but let's make it obvious. - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_4(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint32_t len = uint32le_decode(&ctx->net_bytes[base_offset]);\n" - ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" - ret += "}\n" - continue - if typ.name == "s": # SPECIAL - # Add an extra nul-byte on the host, and validate UTF-8 - # (also, similar optimization to "d"). - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_2(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint16_t len = uint16le_decode(&ctx->net_bytes[base_offset]);\n" - ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)))\n" - ret += "\t\treturn true;\n" - ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" - ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' - ret += "\treturn false;\n" - ret += "}\n" - continue - match typ: case idl.Number(): ret += f"\treturn validate_{typ.prim.name}(ctx);\n" @@ -543,9 +516,18 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue + def should_save_value(member: idl.StructMember) -> bool: + nonlocal typ + assert isinstance(typ, idl.Struct) + return bool( + member.max + or member.val + or any(m.cnt == member for m in typ.members) + ) + # Pass 1 - declare value variables for member in typ.members: - if member.max or member.val: + if should_save_value(member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ)} {member.name};\n" ret += ifdef_pop(1) @@ -561,27 +543,24 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, # Pass 3 - main pass ret += "\treturn false\n" - prev_size: int | None = None for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: - assert prev_size - if prev_size == 1: - ret += f"_validate_list(ctx, ctx->net_bytes[ctx->net_offset-1], validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" - else: - ret += f"_validate_list(ctx, uint{prev_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if typ.name == "s": # SPECIAL (string) + ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' else: - if member.max or member.val: + if should_save_value(member): ret += "(" if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" if member.name in mark_offset: ret += "; })" - if member.max or member.val: + if should_save_value(member): nbytes = member.static_size assert nbytes if nbytes == 1: @@ -591,7 +570,6 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, if member.in_versions != typ.in_versions: ret += " )" ret += "\n" - prev_size = member.static_size # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: @@ -669,7 +647,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" - if typ.name in ["d", "d_signed", "s"]: # SPECIAL + if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" else: @@ -772,7 +750,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) ret += "({ bool err = false;\n" ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n" ret += "\t \terr = " - if typ.name in ["d", "d_signed", "s"]: # SPECIAL + if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" else: @@ -855,11 +833,34 @@ LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct lib9p_s if __name__ == "__main__": import sys + if typing.TYPE_CHECKING: + + class ANSIColors: + MAGENTA = "\x1b[35m" + RED = "\x1b[31m" + RESET = "\x1b[0m" + + else: + from _colorize import ANSIColors # Present in Python 3.13+ + if len(sys.argv) < 2: raise ValueError("requires at least 1 .9p filename") parser = idl.Parser() for txtname in sys.argv[1:]: - parser.parse_file(txtname) + try: + parser.parse_file(txtname) + except SyntaxError as e: + print( + f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}", + file=sys.stderr, + ) + assert e.text + print(f"\t{e.text}", file=sys.stderr) + print( + f"\t{ANSIColors.RED}{'~'*len(e.text)}{ANSIColors.RESET}", + file=sys.stderr, + ) + sys.exit(2) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: |