diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 589 |
1 files changed, 297 insertions, 292 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index 1f5e48c..8e7639a 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -17,6 +17,7 @@ from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast # Types ######################################################################## + class Primitive(enum.Enum): u8 = 1 u16 = 2 @@ -128,7 +129,7 @@ class StructMember: # from left-to-right when parsing cnt: str | None = None name: str - typ: 'Type' + typ: "Type" max: Expr val: Expr @@ -174,7 +175,7 @@ class Message(Struct): Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message -#type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13 +# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13 T = TypeVar("T", Number, Bitfield, Struct, Message) # Parse *.9p ################################################################### @@ -234,6 +235,8 @@ def parse_expr(expr: str) -> Expr: ret = Expr() for tok in re.split("([-+])", expr): if tok == "-" or tok == "+": + # I, for the life of me, do not understand why I need this + # cast() to keep mypy happy. ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] elif re.fullmatch("[0-9]+", tok): ret.tokens += [ExprLit(int(tok))] @@ -451,8 +454,10 @@ def parse_file( # Generate C ################################################################### +idprefix = "lib9p_" + -def c_ver_enum(idprefix: str, ver: str) -> str: +def c_ver_enum(ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" @@ -462,15 +467,13 @@ def c_ver_ifdef(versions: set[str]) -> str: ) -def c_ver_cond(idprefix: str, versions: set[str]) -> str: +def c_ver_cond(versions: set[str]) -> str: if len(versions) == 1: - return f"(ctx->ctx->version=={c_ver_enum(idprefix, next(v for v in versions))})" - return ( - "( " + (" || ".join(c_ver_cond(idprefix, {v}) for v in sorted(versions))) + " )" - ) + return f"(ctx->ctx->version=={c_ver_enum(next(v for v in versions))})" + return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" -def c_typename(idprefix: str, typ: Type) -> str: +def c_typename(typ: Type) -> str: match typ: case Primitive(): return f"uint{typ.value*8}_t" @@ -517,6 +520,7 @@ def ifdef_push(n: int, _newval: str) -> str: ret += f"#if {newval}\n" return ret + def ifdef_pop(n: int) -> str: global _ifdef_stack ret = "" @@ -527,7 +531,7 @@ def ifdef_pop(n: int) -> str: return ret -def gen_h(idprefix: str, versions: set[str], typs: list[Type]) -> str: +def gen_h(versions: set[str], typs: list[Type]) -> str: global _ifdef_stack _ifdef_stack = [] @@ -550,10 +554,10 @@ enum {idprefix}version {{ for ver in fullversions: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t{c_ver_enum(idprefix, ver)}," + ret += f"\t{c_ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += ifdef_pop(0) - ret += f"\t{c_ver_enum(idprefix, 'NUM')},\n" + ret += f"\t{c_ver_enum('NUM')},\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" @@ -566,9 +570,9 @@ enum {idprefix}version {{ ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) match typ: case Number(): - ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" + ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" case Bitfield(): - ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" + ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" names = [ *reversed( [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))] @@ -584,28 +588,36 @@ enum {idprefix}version {{ ret += "\n" elif name.startswith(" "): ret += ifdef_push(2, c_ver_ifdef(typ.in_versions)) - sp = ' '*(len('# define ')+len(idprefix)+len(typ.name)+1+namewidth+2-len("/* unused")) - ret += f"/* unused{sp}(({c_typename(idprefix, typ)})(1<<{name[1:]})) */\n" + sp = " " * ( + len("# define ") + + len(idprefix) + + len(typ.name) + + 1 + + namewidth + + 2 + - len("/* unused") + ) + ret += f"/* unused{sp}(({c_typename(typ)})(1<<{name[1:]})) */\n" else: ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions)) if name.startswith("_"): c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}" else: c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}" - sp1 = ' ' if _ifdef_stack[-1] else '' - sp2 = ' ' if _ifdef_stack[-1] else ' ' - sp3 = ' '*(2+namewidth-len(name)) - ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(idprefix, typ)})({typ.names[name].val}))\n" + sp1 = " " if _ifdef_stack[-1] else "" + sp2 = " " if _ifdef_stack[-1] else " " + sp3 = " " * (2 + namewidth - len(name)) + ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n" ret += ifdef_pop(1) case Struct(): - typewidth = max(len(c_typename(idprefix, m.typ)) for m in typ.members) + typewidth = max(len(c_typename(m.typ)) for m in typ.members) - ret += c_typename(idprefix, typ) + " {\n" + ret += c_typename(typ) + " {\n" for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - c_type = c_typename(idprefix, member.typ) + c_type = c_typename(member.typ) if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL c_type = "char" ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" @@ -630,19 +642,19 @@ enum {idprefix}version {{ for msg in [msg for msg in typs if isinstance(msg, Message)]: ret += "\n" ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += c_typename(idprefix, msg) + " {" + ret += c_typename(msg) + " {" if not msg.members: ret += "};\n" continue ret += "\n" - typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) + typewidth = max(len(c_typename(m.typ)) for m in msg.members) for member in msg.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -665,7 +677,7 @@ def c_expr(expr: Expr) -> str: return " ".join(ret) -def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str: +def gen_c(versions: set[str], typs: list[Type]) -> str: global _ifdef_stack _ifdef_stack = [] @@ -692,17 +704,17 @@ def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str: ret += f""" /* strings ********************************************************************/ -static const char *version_strs[{c_ver_enum(idprefix, 'NUM')}] = {{ +static const char *version_strs[{c_ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) - ret += f'\t[{c_ver_enum(idprefix, ver)}] = "{ver}",\n' + ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' ret += ifdef_pop(0) ret += "};\n" ret += f""" const char *{idprefix}version_str(enum {idprefix}version ver) {{ - assert(0 <= ver && ver < {c_ver_enum(idprefix, 'NUM')}); + assert(0 <= ver && ver < {c_ver_enum('NUM')}); return version_strs[ver]; }} @@ -764,6 +776,23 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + + if isinstance(typ, Bitfield): + ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + ret += ifdef_push(2, c_ver_ifdef({ver})) + ret += ( + f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + "1" if typ.bit_is_valid(bitname, ver) else "0" + for bitname in reversed(typ.bits) + ) + + ",\n" + ) + ret += ifdef_pop(1) + ret += "};\n" + ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" if typ.name == "d": # SPECIAL @@ -797,28 +826,12 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, case Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" - ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_ver_enum(idprefix, 'NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += ifdef_push(2, c_ver_ifdef({ver})) - ret += ( - f"\t\t[{c_ver_enum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - "1" if typ.bit_is_valid(bitname, ver) else "0" - for bitname in reversed(typ.bits) - ) - + ",\n" - ) - ret += ifdef_pop(1) - ret += "\t};\n" ret += ( - f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" + f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" ) - ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" - ret += "\t\treturn lib9p_errorf(ctx->ctx,\n" - ret += f'\t\t LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' - ret += "\t\t val & ~mask);\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case Struct(): # and Message() if len(typ.members) == 0: @@ -826,14 +839,14 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue - # Pass 1 + # Pass 1 - declare value variables for member in typ.members: if member.max or member.val: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(idprefix, member.typ)} {member.name};\n" + ret += f"\t{c_typename(member.typ)} {member.name};\n" ret += ifdef_pop(1) - # Pass 2 + # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: @@ -842,17 +855,17 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" mark_offset.add(tok.name[1:]) - # Pass 3 + # Pass 3 - main pass ret += "\treturn false\n" prev_size: int | None = None for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: - ret += "( " + c_ver_cond(idprefix, member.in_versions) + " && " + ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: assert prev_size - ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" + ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" else: if member.max or member.val: ret += "(" @@ -871,252 +884,244 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "\n" prev_size = member.static_size - # Pass 4 + # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: if member.max: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\n\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n" - ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n' + ret += f"\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n' if member.val: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\n\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n" - ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n' + ret += f"\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += ifdef_pop(0) - # # unmarshal_* ############################################################## - # ret += """ - # /* unmarshal_* ****************************************************************/ - - # static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { - # *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 1; - # } - - # static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { - # *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 2; - # } - - # static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { - # *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 4; - # } - - # static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { - # *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 8; - # } - # """ - # for typ in typs: - # inline = ( - # " FLATTEN" - # if (isinstance(typ, Struct) and typ.msgid is not None) - # else " ALWAYS_INLINE" - # ) - # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used - # ret += "\n" - # ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" - # match typ: - # case Bitfield(): - # ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" - # case Struct(): - # ret += "\tmemset(out, 0, sizeof(*out));\n" - - # if typ.members: - # struct_versions = typ.members[0].ver - # for member in typ.members: - # if member.valexpr: - # ret += f"\tctx->net_offset += {member.static_size};\n" - # continue - # ret += "\t" - # prefix = "\t" - # if member.ver != struct_versions: - # ret += "if ( " + c_ver_cond(idprefix, member.ver) + " ) " - # prefix = "\t\t" - # if member.cnt: - # if member.ver != struct_versions: - # ret += f"{{\n{prefix}" - # ret += f"out->{member.name} = ctx->extra;\n" - # ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" - # ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" - # if typ.name in ["d", "s"]: # SPECIAL - # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" - # else: - # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" - # if member.ver != struct_versions: - # ret += "\t}\n" - # else: - # ret += ( - # f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" - # ) - # if typ.name == "s": # SPECIAL - # ret += "\tctx->extra++;\n" - # ret += "\tout->utf8[out->len] = '\\0';\n" - # ret += "}\n" - - # # marshal_* ################################################################ - # ret += """ - # /* marshal_* ******************************************************************/ - - # static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { - # lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", - # (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", - # ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), - # ctx->ctx->max_msg_size); - # return true; - # } - - # static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { - # if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) - # return _marshal_too_large(ctx); - # ctx->net_bytes[ctx->net_offset] = *val; - # ctx->net_offset += 1; - # return false; - # } - - # static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { - # if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) - # return _marshal_too_large(ctx); - # encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 2; - # return false; - # } - - # static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { - # if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) - # return true; - # encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 4; - # return false; - # } - - # static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { - # if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) - # return true; - # encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); - # ctx->net_offset += 8; - # return false; - # } - # """ - # for typ in typs: - # inline = ( - # " FLATTEN" - # if (isinstance(typ, Struct) and typ.msgid is not None) - # else " ALWAYS_INLINE" - # ) - # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used - # ret += "\n" - # ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" - # match typ: - # case Bitfield(): - # ret += "\n" - # ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" - # case Struct(): - # if len(typ.members) == 0: - # ret += "\n\treturn false;\n" - # ret += "}\n" - # continue - - # mark_offset = set() - # for member in typ.members: - # if member.valexpr: - # if member.name not in mark_offset: - # ret += f"\n\tuint32_t _{member.name}_offset;" - # mark_offset.add(member.name) - # for tok in member.valexpr: - # if ( - # isinstance(tok, ExprVal) - # and tok.name.startswith("&") - # and tok.name[1:] not in mark_offset - # ): - # ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" - # mark_offset.add(tok.name[1:]) - - # prefix0 = "\treturn " - # prefix1 = "\t || " - # prefix2 = "\t " - - # struct_versions = typ.members[0].ver - # prefix = prefix0 - # for member in typ.members: - # ret += f"\n{prefix}" - # if member.ver != struct_versions: - # ret += "( " + c_ver_cond(idprefix, member.ver) + " && " - # if member.name in mark_offset: - # ret += f"({{ _{member.name}_offset = ctx->net_offset; " - # if member.cnt: - # ret += "({" - # ret += f"\n{prefix2}\tbool err = false;" - # ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" - # if typ.name in ["d", "s"]: # SPECIAL - # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);" - # else: - # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" - # ret += f"\n{prefix2}\terr;" - # ret += f"\n{prefix2}}})" - # elif member.valexpr: - # assert member.static_size - # ret += ( - # f"({{ ctx->net_offset += {member.static_size}; false; }})" - # ) - # else: - # ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" - # if member.name in mark_offset: - # ret += "; })" - # if member.ver != struct_versions: - # ret += " )" - # prefix = prefix1 - - # for member in typ.members: - # if member.valexpr: - # assert member.static_size - # ret += f"\n{prefix}" - # ret += f"({{ encode_u{member.static_size*8}le(" - # for tok in member.valexpr: - # match tok: - # case ExprOp(): - # ret += f" {tok.op}" - # case ExprVal(name="end"): - # ret += " ctx->net_offset" - # case ExprVal(): - # ret += f" _{tok.name[1:]}_offset" - # ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" - - # ret += ";\n" - # ret += "}\n" - - # # vtables ################################################################## - # ret += f""" - # /* vtables ********************************************************************/ - - # #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ - # .basesize = sizeof(struct {idprefix}msg_##typ), \\ - # .validate = validate_##typ, \\ - # .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ - # .marshal = (_marshal_fn_t)marshal_##typ, \\ - # }} - - # struct _vtable_version _{idprefix}vtables[{c_ver_enum(idprefix, 'NUM')}] = {{ - # """ - - # ret += f"\t[{c_ver_enum(idprefix, 'unknown')}] = {{ .msgs = {{\n" - # for msg in just_structs_msg(typs): - # if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL - # ret += f"\t\t_MSG({msg.name}),\n" - # ret += "\t}},\n" - - # for ver in sorted(versions): - # ret += f"\t[{c_ver_enum(idprefix, ver)}] = {{ .msgs = {{\n" - # for msg in just_structs_msg(typs): - # if ver not in msg.msgver: - # continue - # ret += f"\t\t_MSG({msg.name}),\n" - # ret += "\t}},\n" - # ret += "};\n" + # unmarshal_* ############################################################## + ret += """ +/* unmarshal_* ****************************************************************/ + +static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { + *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 1; +} + +static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { + *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 2; +} + +static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { + *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 4; +} + +static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { + *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 8; +} +""" + for typ in typs: + inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"static {inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + match typ: + case Number(): + ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + case Bitfield(): + ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + case Struct(): + ret += "\tmemset(out, 0, sizeof(*out));\n" + + for member in typ.members: + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + if member.val: + ret += f"\tctx->net_offset += {member.static_size};\n" + continue + ret += "\t" + + prefix = "\t" + if member.in_versions != typ.in_versions: + ret += "if ( " + c_ver_cond(member.in_versions) + " ) " + prefix = "\t\t" + if member.cnt: + if member.in_versions != typ.in_versions: + ret += "{\n" + ret += prefix + ret += f"out->{member.name} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" + ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + if typ.name in ["d", "s"]: # SPECIAL + ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" + else: + ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + if member.in_versions != typ.in_versions: + ret += "\t}\n" + else: + ret += ( + f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + ) + if typ.name == "s": # SPECIAL + ret += "\tctx->extra++;\n" + ret += "\tout->utf8[out->len] = '\\0';\n" + ret += ifdef_pop(1) + ret += "}\n" + ret += ifdef_pop(0) + + # marshal_* ################################################################ + ret += """ +/* marshal_* ******************************************************************/ + +static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { + lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", + (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", + ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), + ctx->ctx->max_msg_size); + return true; +} + +static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { + if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) + return _marshal_too_large(ctx); + ctx->net_bytes[ctx->net_offset] = *val; + ctx->net_offset += 1; + return false; +} + +static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { + if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) + return _marshal_too_large(ctx); + encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 2; + return false; +} + +static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { + if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) + return true; + encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 4; + return false; +} + +static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { + if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) + return true; + encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 8; + return false; +} +""" + for typ in typs: + inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"static {inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n" + match typ: + case Number(): + ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n" + case Bitfield(): + ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n" + ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n" + case Struct(): + if len(typ.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + + # Pass 1 - declare offset variables + mark_offset = set() + for member in typ.members: + if member.val: + if member.name not in mark_offset: + ret += f"\tuint32_t _{member.name}_offset;\n" + mark_offset.add(member.name) + for tok in member.val.tokens: + if isinstance(tok, ExprSym) and tok.name.startswith("&"): + if tok.name[1:] not in mark_offset: + ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + mark_offset.add(tok.name[1:]) + + # Pass 2 - main pass + ret += "\treturn false\n" + for member in typ.members: + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += "\t || " + if member.in_versions != typ.in_versions: + ret += "( " + c_ver_cond(member.in_versions) + " && " + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " + if member.cnt: + ret += "({ bool err = false;\n" + ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n" + ret += "\t \terr = " + if typ.name in ["d", "s"]: # SPECIAL + # Special-case is that we cast from `char` to `uint8_t`. + ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" + else: + ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n" + ret += f"\t err; }})" + elif member.val: + # Just increment net_offset, don't actually marsha anything (yet). + assert member.static_size + ret += ( + f"({{ ctx->net_offset += {member.static_size}; false; }})" + ) + else: + ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.name in mark_offset: + ret += "; })" + if member.in_versions != typ.in_versions: + ret += " )" + ret += "\n" + + # Pass 3 - marshal ,val= members + for member in typ.members: + if member.val: + assert member.static_size + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n" + + ret += ifdef_pop(1) + ret += "\t ;\n" + ret += "}\n" + + # vtables ################################################################## + ret += f""" +/* vtables ********************************************************************/ + +#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ + .basesize = sizeof(struct {idprefix}msg_##typ), \\ + .validate = validate_##typ, \\ + .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ + .marshal = (_marshal_fn_t)marshal_##typ, \\ + }} + +struct _vtable_version _{idprefix}vtables[{c_ver_enum('NUM')}] = {{ +""" + + ret += f"\t[{c_ver_enum('unknown')}] = {{ .msgs = {{\n" + for msg in [msg for msg in typs if isinstance(msg, Message)]: + if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL + ret += f"\t\t_MSG({msg.name}),\n" + ret += "\t}},\n" + + for ver in sorted(versions): + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n" + for msg in [msg for msg in typs if isinstance(msg, Message)]: + if ver not in msg.in_versions: + continue + ret += f"\t\t_MSG({msg.name}),\n" + ret += "\t}},\n" + ret += ifdef_pop(0) + ret += "};\n" ############################################################################ return ret @@ -1165,6 +1170,6 @@ if __name__ == "__main__": versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: - fh.write(gen_h("lib9p_", versions, typs)) + fh.write(gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: - fh.write(gen_c("lib9p_", versions, typs)) + fh.write(gen_c(versions, typs)) |