diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-10-02 13:55:08 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-10-02 13:55:08 -0600 |
commit | b7a140a42f272aadb5400f013cd21a7a218d26e8 (patch) | |
tree | 6770864edd0330704b43d4b2c8f9d8240f691678 /lib9p | |
parent | fa7f6a5176a386f8847810f34538d5c0e56f0fb1 (diff) |
wip validate, bitfield
Diffstat (limited to 'lib9p')
-rw-r--r-- | lib9p/9p.c | 4 | ||||
-rw-r--r-- | lib9p/srv.c | 10 | ||||
-rw-r--r-- | lib9p/types.c | 12 | ||||
-rwxr-xr-x | lib9p/types.gen | 245 |
4 files changed, 165 insertions, 106 deletions
@@ -68,7 +68,7 @@ ssize_t lib9p_validate(struct lib9p_ctx *ctx, uint8_t *net_bytes) { return -1; assert(subctx.net_offset <= subctx.net_size); if (subctx.net_offset < subctx.net_size) - return lib9p_error(ctx, LINUX_EBADMSG, "message has %"PRIu32" extra bytes", + return lib9p_errorf(ctx, LINUX_EBADMSG, "message has %"PRIu32" extra bytes", subctx.net_size - subctx.net_offset); ssize_t ret; @@ -92,7 +92,7 @@ void lib9p_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes, /* Body */ struct _vtable_msg vtable = _lib9p_vtables[ctx->version].msgs[*ret_typ]; - subctx.extra = ret_body + vtable.unmarshal_basesize; + subctx.extra = ret_body + vtable.basesize; vtable.unmarshal(&subctx, ret_body); } diff --git a/lib9p/srv.c b/lib9p/srv.c index af713e2..f25fe09 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -136,7 +136,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { nonrespond_errorf("accept: %s", strerror(-conn.fd)); continue; } - + struct lib9p_sess sess = { .parent_conn = &conn, .version = 0, @@ -222,7 +222,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) { lib9p_msg_type_str(typ)); goto write; } - ssize_t host_size = lib9p_unmarshal_size(&req.ctx, net); + ssize_t host_size = lib9p_validate(&req.ctx, net); if (host_size < 0) goto write; if ((size_t)host_size > sizeof(host_req)) { @@ -231,7 +231,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) { goto write; } lib9p_unmarshal(&req.ctx, net, &typ, &req.tag, host_req); - + /* Handle it. */ switch (typ) { case LIB9P_TYP_Tversion: @@ -307,8 +307,8 @@ static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *re enum lib9p_version version = LIB9P_VER_unknown; if (req->version.len >= 6 && - req->version.utf8[0] == '9' && - req->version.utf8[1] == 'P' && + req->version.utf8[0] == '9' && + req->version.utf8[1] == 'P' && '0' <= req->version.utf8[2] && req->version.utf8[2] <= '9' && '0' <= req->version.utf8[3] && req->version.utf8[3] <= '9' && '0' <= req->version.utf8[4] && req->version.utf8[4] <= '9' && diff --git a/lib9p/types.c b/lib9p/types.c index 8582161..cb410ea 100644 --- a/lib9p/types.c +++ b/lib9p/types.c @@ -343,6 +343,10 @@ static ALWAYS_INLINE bool validate_s(struct _validate_ctx *ctx) { return false; } +static ALWAYS_INLINE bool validate_qt(struct _validate_ctx *ctx) { + return validate_1(ctx); +} + static ALWAYS_INLINE bool validate_qid(struct _validate_ctx *ctx) { return validate_qt(ctx) || validate_4(ctx) @@ -569,6 +573,10 @@ static ALWAYS_INLINE void unmarshal_s(struct _unmarshal_ctx *ctx, struct lib9p_s unmarshal_1(ctx, &out->utf8[i]); } +static ALWAYS_INLINE void unmarshal_qt(struct _unmarshal_ctx *ctx, lib9p_qt_t *out) { + unmarshal_1(ctx, (uint8_t *)out); +} + static ALWAYS_INLINE void unmarshal_qid(struct _unmarshal_ctx *ctx, struct lib9p_qid *out) { memset(out, 0, sizeof(*out)); unmarshal_qt(ctx, &out->type); @@ -859,6 +867,10 @@ static ALWAYS_INLINE bool marshal_s(struct _marshal_ctx *ctx, struct lib9p_s *va }); } +static ALWAYS_INLINE bool marshal_qt(struct _marshal_ctx *ctx, lib9p_qt_t *val) { + return marshal_1(ctx, (uint8_t *)val); +} + static ALWAYS_INLINE bool marshal_qid(struct _marshal_ctx *ctx, struct lib9p_qid *val) { return marshal_qt(ctx, &val->type) || marshal_4(ctx, &val->vers) diff --git a/lib9p/types.gen b/lib9p/types.gen index a594e45..b41442a 100755 --- a/lib9p/types.gen +++ b/lib9p/types.gen @@ -35,10 +35,18 @@ class Atom(enum.Enum): class Bitfield: name: str bits: list[str] - aliases: dict[str, str] + names: dict[str, str] @property - def static_size(self) -> int | None: + def static_size(self) -> int: + if len(self.bits) <= 8: + return 1 + if len(self.bits) <= 16: + return 2 + if len(self.bits) <= 32: + return 4 + if len(self.bits) <= 64: + return 8 return int((len(self.bits) + 7) / 8) @@ -209,7 +217,7 @@ def parse_file( bf = Bitfield() bf.name = m.group("name") bf.bits = int(m.group("size")) * [""] - bf.aliases = {} + bf.names = {} env[bf.name] = bf prev = bf elif m := re.fullmatch(re_bitfieldspec_bit, line): @@ -232,10 +240,10 @@ def parse_file( raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") if bf.bits[bit]: raise ValueError(f"{bf.name}: bit {bit} already assigned") - if name in bf.aliases: + if name in bf.names: raise ValueError(f"{bf.name}: name {name} already assigned") bf.bits[bit] = name - bf.aliases[name] = "" + bf.names[name] = "" elif m := re.fullmatch(re_bitfieldspec_alias, line): if m.group("bitfield"): if m.group("bitfield") not in env: @@ -252,9 +260,9 @@ def parse_file( bf = prev name = m.group("name") val = m.group("val") - if name in bf.aliases: + if name in bf.names: raise ValueError(f"{bf.name}: name {name} already assigned") - bf.aliases[name] = val + bf.names[name] = val else: raise SyntaxError(f"invalid line {repr(line)}") if not version: @@ -348,10 +356,17 @@ enum {idprefix}version {{ for bf in just_bitfields(typs): ret += "\n" ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n" - vals = dict([ - *reversed([((k or f"_UNUSED_{v}"), f"1<<{v}") for (v, k) in enumerate(bf.bits)]), - *[(k, v) for (k, v) in bf.aliases.items() if v], - ]) + vals = dict( + [ + *reversed( + [ + ((k or f"_UNUSED_{v}"), f"1<<{v}") + for (v, k) in enumerate(bf.bits) + ] + ), + *[(k, v) for (k, v) in bf.names.items() if k not in bf.bits], + ] + ) namewidth = max(len(name) for name in vals) for name, val in vals.items(): ret += f"#define {idprefix.upper()}{bf.name.upper()}_{name.ljust(namewidth)} (({c_typename(idprefix, bf)})({val}))\n" @@ -505,17 +520,17 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, #define validate_4(ctx) _validate_size_net(ctx, 4) #define validate_8(ctx) _validate_size_net(ctx, 8) """ - for struct in just_structs_all(typs): - inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" - argfn = used if struct.members else unused + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" - ret += f"static{inline} bool validate_{struct.name}(struct _validate_ctx *{argfn('ctx')}) {{" - if len(struct.members) == 0: - ret += "\n\treturn false;\n" - ret += "}\n" - continue + ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{" - if struct.name == "d": # SPECIAL + if typ.name == "d": # SPECIAL # Optimize... maybe the compiler could figure out to do # this, but let's make it obvious. ret += "\n" @@ -526,7 +541,7 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" ret += "}\n" continue - if struct.name == "s": # SPECIAL + if typ.name == "s": # SPECIAL # Add an extra nul-byte on the host, and validate UTF-8 # (also, similar optimization to "d"). ret += "\n" @@ -542,27 +557,39 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += "}\n" continue - prefix0 = "\treturn " - prefix1 = "\t || " - - struct_versions = struct.members[0].ver - - prefix = prefix0 - prev_size: int | None = None - for member in struct.members: - ret += f"\n{prefix}" - if member.ver != struct_versions: - ret += "( " + c_vercond(idprefix, member.ver) + " && " - if member.cnt is not None: - assert prev_size - ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" - else: - ret += f"validate_{member.typ.name}(ctx)" - if member.ver != struct_versions: - ret += " )" - prefix = prefix1 - prev_size = member.static_size - ret += ";\n}\n" + match typ: + case Bitfield(): + ret += "\n" + # ret += "\t{c_typename(idprefix, typ)} mask = 0b" + (''.join('1' if b else '0' for b in reversed(typ.bits))) + ";\n" + ret += f"\treturn validate_{typ.static_size}(ctx);\n" + case Struct(): + if len(typ.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + + struct_versions = typ.members[0].ver + + prefix = prefix0 + prev_size: int | None = None + for member in typ.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.cnt is not None: + assert prev_size + ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" + else: + ret += f"validate_{member.typ.name}(ctx)" + if member.ver != struct_versions: + ret += " )" + prefix = prefix1 + prev_size = member.static_size + ret += ";\n" + ret += "}\n" # unmarshal_* ############################################################## ret += """ @@ -588,32 +615,42 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) ctx->net_offset += 8; } """ - for struct in just_structs_all(typs): - inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" - argfn = used if struct.members else unused + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" - ret += f"static{inline} void unmarshal_{struct.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *out) {{\n" - ret += "\tmemset(out, 0, sizeof(*out));\n" - - if struct.members: - struct_versions = struct.members[0].ver - for member in struct.members: - ret += "\t" - prefix = "\t" - if member.ver != struct_versions: - ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " - prefix = "\t\t" - if member.cnt: - if member.ver != struct_versions: - ret += "{\n" - ret += f"{prefix}out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" - ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" - if member.ver != struct_versions: - ret += "\t}\n" - else: - ret += f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" + match typ: + case Bitfield(): + ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" + case Struct(): + ret += "\tmemset(out, 0, sizeof(*out));\n" + + if typ.members: + struct_versions = typ.members[0].ver + for member in typ.members: + ret += "\t" + prefix = "\t" + if member.ver != struct_versions: + ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " + prefix = "\t\t" + if member.cnt: + if member.ver != struct_versions: + ret += "{\n" + ret += f"{prefix}out->{member.name} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" + ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + if member.ver != struct_versions: + ret += "\t}\n" + else: + ret += ( + f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + ) ret += "}\n" # marshal_* ################################################################ @@ -660,39 +697,49 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { return false; } """ - for struct in just_structs_all(typs): - inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" - argfn = used if struct.members else unused + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" - ret += f"static{inline} bool marshal_{struct.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('val')}) {{" - if len(struct.members) == 0: - ret += "\n\treturn false;\n" - ret += "}\n" - continue - - prefix0 = "\treturn " - prefix1 = "\t || " - prefix2 = "\t " - - struct_versions = struct.members[0].ver - prefix = prefix0 - for member in struct.members: - ret += f"\n{prefix}" - if member.ver != struct_versions: - ret += "( " + c_vercond(idprefix, member.ver) + " && " - if member.cnt: - ret += "({" - ret += f"\n{prefix2}\tbool err = false;" - ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" - ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" - ret += f"\n{prefix2}\terr;" - ret += f"\n{prefix2}}})" - else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" - if member.ver != struct_versions: - ret += " )" - prefix = prefix1 - ret += ";\n}\n" + ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" + match typ: + case Bitfield(): + ret += "\n" + ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" + case Struct(): + if len(typ.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + prefix2 = "\t " + + struct_versions = typ.members[0].ver + prefix = prefix0 + for member in typ.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.cnt: + ret += "({" + ret += f"\n{prefix2}\tbool err = false;" + ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" + ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" + ret += f"\n{prefix2}\terr;" + ret += f"\n{prefix2}}})" + else: + ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.ver != struct_versions: + ret += " )" + prefix = prefix1 + ret += ";\n" + ret += "}\n" # vtables ################################################################## ret += f""" |