diff options
-rw-r--r-- | lib9p/9P2000.txt | 6 | ||||
-rwxr-xr-x | lib9p/9p.gen | 167 | ||||
-rw-r--r-- | lib9p/9p.generated.c | 70 |
3 files changed, 176 insertions, 67 deletions
diff --git a/lib9p/9P2000.txt b/lib9p/9P2000.txt index ab898ec..5302f2b 100644 --- a/lib9p/9P2000.txt +++ b/lib9p/9P2000.txt @@ -90,7 +90,7 @@ bitfield qt 8 qid = "type[qt] vers[4] path[8]" # stat (TODO) -stat = "stat_size[2]" +stat = "stat_size[2,val=end-&kern_type]" "kern_type[2]" "kern_dev[4]" "file_qid[qid]" @@ -150,6 +150,6 @@ bitfield o 8 122/Tremove = "fid[4]" 123/Rremove = "" 124/Tstat = "fid[4]" -125/Rstat = "stat[stat]" -126/Twstat = "fid[4] stat[stat]" +125/Rstat = "nstat[2,val=end-&stat] stat[stat]" # See the "BUG" note in the RFC for the nstat field +126/Twstat = "fid[4] nstat[2,val=end-&stat] stat[stat]" # See the "BUG" note in the RFC for the nstat field 127/Rwstat = "" diff --git a/lib9p/9p.gen b/lib9p/9p.gen index 83bdcfe..816ec0a 100755 --- a/lib9p/9p.gen +++ b/lib9p/9p.gen @@ -8,7 +8,7 @@ import enum import os.path import re -from typing import Callable, Sequence +from typing import Callable, Literal, Sequence # This strives to be "general-purpose" in that it just acts on the # *.txt inputs; but (unfortunately?) there are a few special-cases in @@ -84,6 +84,20 @@ class Struct: return size +class ExprVal: + name: str + + def __init__(self, name: str) -> None: + self.name = name + + +class ExprOp: + op: Literal["-", "+"] + + def __init__(self, op: Literal["-", "+"]) -> None: + self.op = op + + # `cnt*(name[typ])` # the `cnt*(...)` wrapper is optional class Member: @@ -91,6 +105,7 @@ class Member: name: str typ: Atom | Bitfield | Struct max: int | None = None + valexpr: list[ExprVal | ExprOp] = [] ver: set[str] @property @@ -100,8 +115,21 @@ class Member: return self.typ.static_size +def parse_valexpr(valexpr: str) -> list[ExprVal | ExprOp]: + ret: list[ExprVal | ExprOp] = [] + for tok in re.split("([-+])", valexpr): + match tok: + case "-": + ret += [ExprOp(tok)] + case "+": + ret += [ExprOp(tok)] + case _: + ret += [ExprVal(tok)] + return ret + + re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" -re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+))?\\]\\)?" +re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+)|,val=(?P<val>[-+&a-zA-Z0-9_]+))*\\]\\)?" def parse_members( @@ -136,11 +164,14 @@ def parse_members( if maxstr := m.group("max"): if (not isinstance(member.typ, Atom)) or member.cnt: - raise ValueError( - f"',max=' may only be specified on a non-repeated atom" - ) + raise ValueError("',max=' may only be specified on a non-repeated atom") member.max = int(maxstr) + if valstr := m.group("val"): + if (not isinstance(member.typ, Atom)) or member.cnt: + raise ValueError("',val=' may only be specified on a non-repeated atom") + member.valexpr = parse_valexpr(valstr) + ret += [member] return ret @@ -312,6 +343,16 @@ def parse_file( raise SyntaxError("must have exactly 1 version line") typs = [x for x in env.values() if not isinstance(x, Atom)] + + for typ in just_structs_all(typs): + valid_vals = ["end", *["&" + m.name for m in typ.members]] + for member in typ.members: + for tok in member.valexpr: + if isinstance(tok, ExprVal) and tok.name not in valid_vals: + raise ValueError( + f"{typ.name}.{member.name}: invalid val: {tok.name}" + ) + return version, typs @@ -436,7 +477,7 @@ enum {idprefix}version {{ ret += "\n" ret += c_typename(idprefix, struct) + " {\n" for member in struct.members: - if struct.name == "stat" and member.name == "stat_size": # SPECIAL + if member.valexpr: continue ctype = c_typename(idprefix, member.typ) if (struct.name in ["d", "s"]) and member.cnt: # SPECIAL @@ -480,6 +521,8 @@ enum {idprefix}version {{ namewidth = max(len(m.name) for m in msg.members) for member in msg.members: + if member.valexpr: + continue ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment @@ -665,8 +708,19 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue - if typ.name == "stat": # SPECIAL - ret += f"\n\tuint32_t size_offset = ctx->net_offset;" + for member in typ.members: + if member.max or member.valexpr: + ret += f"\n\t{c_typename(idprefix, member.typ)} {member.name};" + mark_offset: set[str] = set() + for member in typ.members: + for tok in member.valexpr: + if ( + isinstance(tok, ExprVal) + and tok.name.startswith("&") + and tok.name[1:] not in mark_offset + ): + ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + mark_offset.add(tok.name[1:]) prefix0 = "\treturn " prefix1 = "\t || " @@ -684,22 +738,41 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, assert prev_size ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: + if member.max or member.valexpr: + ret += "(" + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" - if member.max: - assert member.static_size - ret += f"\n{prefix1}(decode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]) > ({c_typename(idprefix, member.typ)})({member.max})" - ret += f'\n{prefix2}\t? lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%"PRIu{member.static_size*8}" > %"PRIu{member.static_size*8}")",' - ret += f"\n{prefix2}\t\tdecode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]), ({c_typename(idprefix, member.typ)})({member.max}))" - ret += f"\n{prefix2}\t: false)" + if member.name in mark_offset: + ret += "; })" + if member.max or member.valexpr: + bytes = member.static_size + assert bytes + bits = bytes * 8 + ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" + if member.max: + ret += f"\n{prefix1}" + ret += f'({member.name} > UINT{bits}_C({member.max}) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu{bits}" > %"PRIu{bits}")", {member.name}, UINT{bits}_C({member.max})))' if member.ver != struct_versions: ret += " )" prefix = prefix1 prev_size = member.static_size - if typ.name == "stat": # SPECIAL - assert typ.members[0].static_size - ret += f"\n{prefix1}((uint32_t)decode_u{typ.members[0].static_size*8}le(&ctx->net_bytes[size_offset]) != ctx->net_offset - (size_offset+{typ.members[0].static_size})" - ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_EBADMSG, "stat size does not match stat contents")' - ret += f"\n{prefix2}\t: false)" + + for member in typ.members: + if member.valexpr: + ret += f"\n{prefix}" + ret += f"({{ uint32_t correct =" + for tok in member.valexpr: + match tok: + case ExprOp(): + ret += f" {tok.op}" + case ExprVal(name="end"): + ret += " ctx->net_offset" + case ExprVal(): + ret += f" _{tok.name[1:]}_offset" + ret += f"; (((uint32_t){member.name}) != correct) &&" + ret += f'\n{prefix2}lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, correct); }})' + ret += ";\n" ret += "}\n" @@ -745,7 +818,7 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) if typ.members: struct_versions = typ.members[0].ver for member in typ.members: - if typ.name == "stat" and member.name == "stat_size": # SPECIAL + if member.valexpr: ret += f"\tctx->net_offset += {member.static_size};\n" continue ret += "\t" @@ -837,8 +910,20 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += "}\n" continue - if typ.name == "stat": # SPECIAL - ret += "\n\tuint32_t size_offset = ctx->net_offset;" + mark_offset = set() + for member in typ.members: + if member.valexpr: + if member.name not in mark_offset: + ret += f"\n\tuint32_t _{member.name}_offset;" + mark_offset.add(member.name) + for tok in member.valexpr: + if ( + isinstance(tok, ExprVal) + and tok.name.startswith("&") + and tok.name[1:] not in mark_offset + ): + ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + mark_offset.add(tok.name[1:]) prefix0 = "\treturn " prefix1 = "\t || " @@ -847,16 +932,11 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { struct_versions = typ.members[0].ver prefix = prefix0 for member in typ.members: - if typ.name == "stat" and member.name == "stat_size": # SPECIAL: - assert member.static_size - ret += f"\n{prefix }(ctx->net_offset + {member.static_size} > ctx->ctx->max_msg_size" - ret += f"\n{prefix2}\t? _marshal_too_large(ctx)" - ret += f"\n{prefix2}\t: ({{ ctx->net_offset += {member.static_size}; false; }}))" - prefix = prefix1 - continue ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " if member.cnt: ret += "({" ret += f"\n{prefix2}\tbool err = false;" @@ -867,17 +947,34 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" + elif member.valexpr: + assert member.static_size + ret += ( + f"({{ ctx->net_offset += {member.static_size}; false; }})" + ) else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.name in mark_offset: + ret += "; })" if member.ver != struct_versions: ret += " )" prefix = prefix1 - if typ.name == "stat": # SPECIAL - assert typ.members[0].static_size - ret += f"\n{prefix1}((ctx->net_offset - (size_offset+{typ.members[0].static_size}) > UINT16_MAX)" - ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_ERANGE, "stat object too large")' - ret += f"\n{prefix2}\t: ({{ encode_u{typ.members[0].static_size*8}le((uint{typ.members[0].static_size*8}_t)(ctx->net_offset - (size_offset+{typ.members[0].static_size})), &ctx->net_bytes[size_offset]);" - ret += f"\n{prefix2} false; }}))" + + for member in typ.members: + if member.valexpr: + assert member.static_size + ret += f"\n{prefix}" + ret += f"({{ encode_u{member.static_size*8}le(" + for tok in member.valexpr: + match tok: + case ExprOp(): + ret += f" {tok.op}" + case ExprVal(name="end"): + ret += " ctx->net_offset" + case ExprVal(): + ret += f" _{tok.name[1:]}_offset" + ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" + ret += ";\n" ret += "}\n" diff --git a/lib9p/9p.generated.c b/lib9p/9p.generated.c index b2243e2..b24c6e8 100644 --- a/lib9p/9p.generated.c +++ b/lib9p/9p.generated.c @@ -383,9 +383,10 @@ static ALWAYS_INLINE bool validate_qid(struct _validate_ctx *ctx) { } static ALWAYS_INLINE bool validate_stat(struct _validate_ctx *ctx) { - uint32_t size_offset = ctx->net_offset; - return validate_2(ctx) - || validate_2(ctx) + uint16_t stat_size; + uint32_t _kern_type_offset; + return (validate_2(ctx) || ({ stat_size = decode_u16le(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || ({ _kern_type_offset = ctx->net_offset; validate_2(ctx); }) || validate_4(ctx) || validate_qid(ctx) || validate_dm(ctx) @@ -400,9 +401,8 @@ static ALWAYS_INLINE bool validate_stat(struct _validate_ctx *ctx) { || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && validate_4(ctx) ) || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && validate_4(ctx) ) || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && validate_4(ctx) ) - || ((uint32_t)decode_u16le(&ctx->net_bytes[size_offset]) != ctx->net_offset - (size_offset+2) - ? lib9p_error(ctx->ctx, LINUX_EBADMSG, "stat size does not match stat contents") - : false); + || ({ uint32_t correct = ctx->net_offset - _kern_type_offset; (((uint32_t)stat_size) != correct) && + lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "stat_size value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)stat_size, correct); }); } static ALWAYS_INLINE bool validate_o(struct _validate_ctx *ctx) { @@ -463,22 +463,18 @@ static FLATTEN bool validate_Rflush(struct _validate_ctx *UNUSED(ctx)) { } static FLATTEN bool validate_Twalk(struct _validate_ctx *ctx) { + uint16_t nwname; return validate_4(ctx) || validate_4(ctx) - || validate_2(ctx) - || (decode_u16le(&ctx->net_bytes[ctx->net_offset-2]) > (uint16_t)(16) - ? lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%"PRIu16" > %"PRIu16")", - decode_u16le(&ctx->net_bytes[ctx->net_offset-2]), (uint16_t)(16)) - : false) + || (validate_2(ctx) || ({ nwname = decode_u16le(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || (nwname > UINT16_C(16) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "nwname value too large (%"PRIu16" > %"PRIu16")", nwname, UINT16_C(16))) || _validate_list(ctx, decode_u16le(&ctx->net_bytes[ctx->net_offset-2]), validate_s, sizeof(struct lib9p_s)); } static FLATTEN bool validate_Rwalk(struct _validate_ctx *ctx) { - return validate_2(ctx) - || (decode_u16le(&ctx->net_bytes[ctx->net_offset-2]) > (uint16_t)(16) - ? lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%"PRIu16" > %"PRIu16")", - decode_u16le(&ctx->net_bytes[ctx->net_offset-2]), (uint16_t)(16)) - : false) + uint16_t nwqid; + return (validate_2(ctx) || ({ nwqid = decode_u16le(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || (nwqid > UINT16_C(16) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "nwqid value too large (%"PRIu16" > %"PRIu16")", nwqid, UINT16_C(16))) || _validate_list(ctx, decode_u16le(&ctx->net_bytes[ctx->net_offset-2]), validate_qid, sizeof(struct lib9p_qid)); } @@ -545,12 +541,22 @@ static FLATTEN bool validate_Tstat(struct _validate_ctx *ctx) { } static FLATTEN bool validate_Rstat(struct _validate_ctx *ctx) { - return validate_stat(ctx); + uint16_t nstat; + uint32_t _stat_offset; + return (validate_2(ctx) || ({ nstat = decode_u16le(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || ({ _stat_offset = ctx->net_offset; validate_stat(ctx); }) + || ({ uint32_t correct = ctx->net_offset - _stat_offset; (((uint32_t)nstat) != correct) && + lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "nstat value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)nstat, correct); }); } static FLATTEN bool validate_Twstat(struct _validate_ctx *ctx) { + uint16_t nstat; + uint32_t _stat_offset; return validate_4(ctx) - || validate_stat(ctx); + || (validate_2(ctx) || ({ nstat = decode_u16le(&ctx->net_bytes[ctx->net_offset-2]); false; })) + || ({ _stat_offset = ctx->net_offset; validate_stat(ctx); }) + || ({ uint32_t correct = ctx->net_offset - _stat_offset; (((uint32_t)nstat) != correct) && + lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "nstat value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t)nstat, correct); }); } static FLATTEN bool validate_Rwstat(struct _validate_ctx *UNUSED(ctx)) { @@ -816,12 +822,14 @@ static FLATTEN void unmarshal_Tstat(struct _unmarshal_ctx *ctx, struct lib9p_msg static FLATTEN void unmarshal_Rstat(struct _unmarshal_ctx *ctx, struct lib9p_msg_Rstat *out) { memset(out, 0, sizeof(*out)); + ctx->net_offset += 2; unmarshal_stat(ctx, &out->stat); } static FLATTEN void unmarshal_Twstat(struct _unmarshal_ctx *ctx, struct lib9p_msg_Twstat *out) { memset(out, 0, sizeof(*out)); unmarshal_4(ctx, &out->fid); + ctx->net_offset += 2; unmarshal_stat(ctx, &out->stat); } @@ -946,11 +954,10 @@ static ALWAYS_INLINE bool marshal_qid(struct _marshal_ctx *ctx, struct lib9p_qid } static ALWAYS_INLINE bool marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val) { - uint32_t size_offset = ctx->net_offset; - return (ctx->net_offset + 2 > ctx->ctx->max_msg_size - ? _marshal_too_large(ctx) - : ({ ctx->net_offset += 2; false; })) - || marshal_2(ctx, &val->kern_type) + uint32_t _stat_size_offset; + uint32_t _kern_type_offset; + return ({ _stat_size_offset = ctx->net_offset; ({ ctx->net_offset += 2; false; }); }) + || ({ _kern_type_offset = ctx->net_offset; marshal_2(ctx, &val->kern_type); }) || marshal_4(ctx, &val->kern_dev) || marshal_qid(ctx, &val->file_qid) || marshal_dm(ctx, &val->file_mode) @@ -965,10 +972,7 @@ static ALWAYS_INLINE bool marshal_stat(struct _marshal_ctx *ctx, struct lib9p_st || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_uid) ) || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_gid) ) || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_last_modified_n_uid) ) - || ((ctx->net_offset - (size_offset+2) > UINT16_MAX) - ? lib9p_error(ctx->ctx, LINUX_ERANGE, "stat object too large") - : ({ encode_u16le((uint16_t)(ctx->net_offset - (size_offset+2)), &ctx->net_bytes[size_offset]); - false; })); + || ({ encode_u16le( ctx->net_offset - _kern_type_offset, &ctx->net_bytes[_stat_size_offset]); false; }); } static ALWAYS_INLINE bool marshal_o(struct _marshal_ctx *ctx, lib9p_o_t *val) { @@ -1106,12 +1110,20 @@ static FLATTEN bool marshal_Tstat(struct _marshal_ctx *ctx, struct lib9p_msg_Tst } static FLATTEN bool marshal_Rstat(struct _marshal_ctx *ctx, struct lib9p_msg_Rstat *val) { - return marshal_stat(ctx, &val->stat); + uint32_t _nstat_offset; + uint32_t _stat_offset; + return ({ _nstat_offset = ctx->net_offset; ({ ctx->net_offset += 2; false; }); }) + || ({ _stat_offset = ctx->net_offset; marshal_stat(ctx, &val->stat); }) + || ({ encode_u16le( ctx->net_offset - _stat_offset, &ctx->net_bytes[_nstat_offset]); false; }); } static FLATTEN bool marshal_Twstat(struct _marshal_ctx *ctx, struct lib9p_msg_Twstat *val) { + uint32_t _nstat_offset; + uint32_t _stat_offset; return marshal_4(ctx, &val->fid) - || marshal_stat(ctx, &val->stat); + || ({ _nstat_offset = ctx->net_offset; ({ ctx->net_offset += 2; false; }); }) + || ({ _stat_offset = ctx->net_offset; marshal_stat(ctx, &val->stat); }) + || ({ encode_u16le( ctx->net_offset - _stat_offset, &ctx->net_bytes[_nstat_offset]); false; }); } static FLATTEN bool marshal_Rwstat(struct _marshal_ctx *UNUSED(ctx), struct lib9p_msg_Rwstat *UNUSED(val)) { |