diff options
Diffstat (limited to 'lib9p/9p.gen')
-rwxr-xr-x | lib9p/9p.gen | 167 |
1 files changed, 132 insertions, 35 deletions
diff --git a/lib9p/9p.gen b/lib9p/9p.gen index 83bdcfe..816ec0a 100755 --- a/lib9p/9p.gen +++ b/lib9p/9p.gen @@ -8,7 +8,7 @@ import enum import os.path import re -from typing import Callable, Sequence +from typing import Callable, Literal, Sequence # This strives to be "general-purpose" in that it just acts on the # *.txt inputs; but (unfortunately?) there are a few special-cases in @@ -84,6 +84,20 @@ class Struct: return size +class ExprVal: + name: str + + def __init__(self, name: str) -> None: + self.name = name + + +class ExprOp: + op: Literal["-", "+"] + + def __init__(self, op: Literal["-", "+"]) -> None: + self.op = op + + # `cnt*(name[typ])` # the `cnt*(...)` wrapper is optional class Member: @@ -91,6 +105,7 @@ class Member: name: str typ: Atom | Bitfield | Struct max: int | None = None + valexpr: list[ExprVal | ExprOp] = [] ver: set[str] @property @@ -100,8 +115,21 @@ class Member: return self.typ.static_size +def parse_valexpr(valexpr: str) -> list[ExprVal | ExprOp]: + ret: list[ExprVal | ExprOp] = [] + for tok in re.split("([-+])", valexpr): + match tok: + case "-": + ret += [ExprOp(tok)] + case "+": + ret += [ExprOp(tok)] + case _: + ret += [ExprVal(tok)] + return ret + + re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" -re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+))?\\]\\)?" +re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+)|,val=(?P<val>[-+&a-zA-Z0-9_]+))*\\]\\)?" def parse_members( @@ -136,11 +164,14 @@ def parse_members( if maxstr := m.group("max"): if (not isinstance(member.typ, Atom)) or member.cnt: - raise ValueError( - f"',max=' may only be specified on a non-repeated atom" - ) + raise ValueError("',max=' may only be specified on a non-repeated atom") member.max = int(maxstr) + if valstr := m.group("val"): + if (not isinstance(member.typ, Atom)) or member.cnt: + raise ValueError("',val=' may only be specified on a non-repeated atom") + member.valexpr = parse_valexpr(valstr) + ret += [member] return ret @@ -312,6 +343,16 @@ def parse_file( raise SyntaxError("must have exactly 1 version line") typs = [x for x in env.values() if not isinstance(x, Atom)] + + for typ in just_structs_all(typs): + valid_vals = ["end", *["&" + m.name for m in typ.members]] + for member in typ.members: + for tok in member.valexpr: + if isinstance(tok, ExprVal) and tok.name not in valid_vals: + raise ValueError( + f"{typ.name}.{member.name}: invalid val: {tok.name}" + ) + return version, typs @@ -436,7 +477,7 @@ enum {idprefix}version {{ ret += "\n" ret += c_typename(idprefix, struct) + " {\n" for member in struct.members: - if struct.name == "stat" and member.name == "stat_size": # SPECIAL + if member.valexpr: continue ctype = c_typename(idprefix, member.typ) if (struct.name in ["d", "s"]) and member.cnt: # SPECIAL @@ -480,6 +521,8 @@ enum {idprefix}version {{ namewidth = max(len(m.name) for m in msg.members) for member in msg.members: + if member.valexpr: + continue ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment @@ -665,8 +708,19 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue - if typ.name == "stat": # SPECIAL - ret += f"\n\tuint32_t size_offset = ctx->net_offset;" + for member in typ.members: + if member.max or member.valexpr: + ret += f"\n\t{c_typename(idprefix, member.typ)} {member.name};" + mark_offset: set[str] = set() + for member in typ.members: + for tok in member.valexpr: + if ( + isinstance(tok, ExprVal) + and tok.name.startswith("&") + and tok.name[1:] not in mark_offset + ): + ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + mark_offset.add(tok.name[1:]) prefix0 = "\treturn " prefix1 = "\t || " @@ -684,22 +738,41 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, assert prev_size ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: + if member.max or member.valexpr: + ret += "(" + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" - if member.max: - assert member.static_size - ret += f"\n{prefix1}(decode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]) > ({c_typename(idprefix, member.typ)})({member.max})" - ret += f'\n{prefix2}\t? lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%"PRIu{member.static_size*8}" > %"PRIu{member.static_size*8}")",' - ret += f"\n{prefix2}\t\tdecode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]), ({c_typename(idprefix, member.typ)})({member.max}))" - ret += f"\n{prefix2}\t: false)" + if member.name in mark_offset: + ret += "; })" + if member.max or member.valexpr: + bytes = member.static_size + assert bytes + bits = bytes * 8 + ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" + if member.max: + ret += f"\n{prefix1}" + ret += f'({member.name} > UINT{bits}_C({member.max}) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu{bits}" > %"PRIu{bits}")", {member.name}, UINT{bits}_C({member.max})))' if member.ver != struct_versions: ret += " )" prefix = prefix1 prev_size = member.static_size - if typ.name == "stat": # SPECIAL - assert typ.members[0].static_size - ret += f"\n{prefix1}((uint32_t)decode_u{typ.members[0].static_size*8}le(&ctx->net_bytes[size_offset]) != ctx->net_offset - (size_offset+{typ.members[0].static_size})" - ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_EBADMSG, "stat size does not match stat contents")' - ret += f"\n{prefix2}\t: false)" + + for member in typ.members: + if member.valexpr: + ret += f"\n{prefix}" + ret += f"({{ uint32_t correct =" + for tok in member.valexpr: + match tok: + case ExprOp(): + ret += f" {tok.op}" + case ExprVal(name="end"): + ret += " ctx->net_offset" + case ExprVal(): + ret += f" _{tok.name[1:]}_offset" + ret += f"; (((uint32_t){member.name}) != correct) &&" + ret += f'\n{prefix2}lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, correct); }})' + ret += ";\n" ret += "}\n" @@ -745,7 +818,7 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) if typ.members: struct_versions = typ.members[0].ver for member in typ.members: - if typ.name == "stat" and member.name == "stat_size": # SPECIAL + if member.valexpr: ret += f"\tctx->net_offset += {member.static_size};\n" continue ret += "\t" @@ -837,8 +910,20 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += "}\n" continue - if typ.name == "stat": # SPECIAL - ret += "\n\tuint32_t size_offset = ctx->net_offset;" + mark_offset = set() + for member in typ.members: + if member.valexpr: + if member.name not in mark_offset: + ret += f"\n\tuint32_t _{member.name}_offset;" + mark_offset.add(member.name) + for tok in member.valexpr: + if ( + isinstance(tok, ExprVal) + and tok.name.startswith("&") + and tok.name[1:] not in mark_offset + ): + ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + mark_offset.add(tok.name[1:]) prefix0 = "\treturn " prefix1 = "\t || " @@ -847,16 +932,11 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { struct_versions = typ.members[0].ver prefix = prefix0 for member in typ.members: - if typ.name == "stat" and member.name == "stat_size": # SPECIAL: - assert member.static_size - ret += f"\n{prefix }(ctx->net_offset + {member.static_size} > ctx->ctx->max_msg_size" - ret += f"\n{prefix2}\t? _marshal_too_large(ctx)" - ret += f"\n{prefix2}\t: ({{ ctx->net_offset += {member.static_size}; false; }}))" - prefix = prefix1 - continue ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " if member.cnt: ret += "({" ret += f"\n{prefix2}\tbool err = false;" @@ -867,17 +947,34 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" + elif member.valexpr: + assert member.static_size + ret += ( + f"({{ ctx->net_offset += {member.static_size}; false; }})" + ) else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.name in mark_offset: + ret += "; })" if member.ver != struct_versions: ret += " )" prefix = prefix1 - if typ.name == "stat": # SPECIAL - assert typ.members[0].static_size - ret += f"\n{prefix1}((ctx->net_offset - (size_offset+{typ.members[0].static_size}) > UINT16_MAX)" - ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_ERANGE, "stat object too large")' - ret += f"\n{prefix2}\t: ({{ encode_u{typ.members[0].static_size*8}le((uint{typ.members[0].static_size*8}_t)(ctx->net_offset - (size_offset+{typ.members[0].static_size})), &ctx->net_bytes[size_offset]);" - ret += f"\n{prefix2} false; }}))" + + for member in typ.members: + if member.valexpr: + assert member.static_size + ret += f"\n{prefix}" + ret += f"({{ encode_u{member.static_size*8}le(" + for tok in member.valexpr: + match tok: + case ExprOp(): + ret += f" {tok.op}" + case ExprVal(name="end"): + ret += " ctx->net_offset" + case ExprVal(): + ret += f" _{tok.name[1:]}_offset" + ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" + ret += ";\n" ret += "}\n" |