summaryrefslogtreecommitdiff
path: root/lib9p
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-02 13:55:08 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-02 13:55:08 -0600
commitb7a140a42f272aadb5400f013cd21a7a218d26e8 (patch)
tree6770864edd0330704b43d4b2c8f9d8240f691678 /lib9p
parentfa7f6a5176a386f8847810f34538d5c0e56f0fb1 (diff)
wip validate, bitfield
Diffstat (limited to 'lib9p')
-rw-r--r--lib9p/9p.c4
-rw-r--r--lib9p/srv.c10
-rw-r--r--lib9p/types.c12
-rwxr-xr-xlib9p/types.gen245
4 files changed, 165 insertions, 106 deletions
diff --git a/lib9p/9p.c b/lib9p/9p.c
index e1303bd..af1c6ca 100644
--- a/lib9p/9p.c
+++ b/lib9p/9p.c
@@ -68,7 +68,7 @@ ssize_t lib9p_validate(struct lib9p_ctx *ctx, uint8_t *net_bytes) {
return -1;
assert(subctx.net_offset <= subctx.net_size);
if (subctx.net_offset < subctx.net_size)
- return lib9p_error(ctx, LINUX_EBADMSG, "message has %"PRIu32" extra bytes",
+ return lib9p_errorf(ctx, LINUX_EBADMSG, "message has %"PRIu32" extra bytes",
subctx.net_size - subctx.net_offset);
ssize_t ret;
@@ -92,7 +92,7 @@ void lib9p_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes,
/* Body */
struct _vtable_msg vtable = _lib9p_vtables[ctx->version].msgs[*ret_typ];
- subctx.extra = ret_body + vtable.unmarshal_basesize;
+ subctx.extra = ret_body + vtable.basesize;
vtable.unmarshal(&subctx, ret_body);
}
diff --git a/lib9p/srv.c b/lib9p/srv.c
index af713e2..f25fe09 100644
--- a/lib9p/srv.c
+++ b/lib9p/srv.c
@@ -136,7 +136,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
nonrespond_errorf("accept: %s", strerror(-conn.fd));
continue;
}
-
+
struct lib9p_sess sess = {
.parent_conn = &conn,
.version = 0,
@@ -222,7 +222,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) {
lib9p_msg_type_str(typ));
goto write;
}
- ssize_t host_size = lib9p_unmarshal_size(&req.ctx, net);
+ ssize_t host_size = lib9p_validate(&req.ctx, net);
if (host_size < 0)
goto write;
if ((size_t)host_size > sizeof(host_req)) {
@@ -231,7 +231,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) {
goto write;
}
lib9p_unmarshal(&req.ctx, net, &typ, &req.tag, host_req);
-
+
/* Handle it. */
switch (typ) {
case LIB9P_TYP_Tversion:
@@ -307,8 +307,8 @@ static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *re
enum lib9p_version version = LIB9P_VER_unknown;
if (req->version.len >= 6 &&
- req->version.utf8[0] == '9' &&
- req->version.utf8[1] == 'P' &&
+ req->version.utf8[0] == '9' &&
+ req->version.utf8[1] == 'P' &&
'0' <= req->version.utf8[2] && req->version.utf8[2] <= '9' &&
'0' <= req->version.utf8[3] && req->version.utf8[3] <= '9' &&
'0' <= req->version.utf8[4] && req->version.utf8[4] <= '9' &&
diff --git a/lib9p/types.c b/lib9p/types.c
index 8582161..cb410ea 100644
--- a/lib9p/types.c
+++ b/lib9p/types.c
@@ -343,6 +343,10 @@ static ALWAYS_INLINE bool validate_s(struct _validate_ctx *ctx) {
return false;
}
+static ALWAYS_INLINE bool validate_qt(struct _validate_ctx *ctx) {
+ return validate_1(ctx);
+}
+
static ALWAYS_INLINE bool validate_qid(struct _validate_ctx *ctx) {
return validate_qt(ctx)
|| validate_4(ctx)
@@ -569,6 +573,10 @@ static ALWAYS_INLINE void unmarshal_s(struct _unmarshal_ctx *ctx, struct lib9p_s
unmarshal_1(ctx, &out->utf8[i]);
}
+static ALWAYS_INLINE void unmarshal_qt(struct _unmarshal_ctx *ctx, lib9p_qt_t *out) {
+ unmarshal_1(ctx, (uint8_t *)out);
+}
+
static ALWAYS_INLINE void unmarshal_qid(struct _unmarshal_ctx *ctx, struct lib9p_qid *out) {
memset(out, 0, sizeof(*out));
unmarshal_qt(ctx, &out->type);
@@ -859,6 +867,10 @@ static ALWAYS_INLINE bool marshal_s(struct _marshal_ctx *ctx, struct lib9p_s *va
});
}
+static ALWAYS_INLINE bool marshal_qt(struct _marshal_ctx *ctx, lib9p_qt_t *val) {
+ return marshal_1(ctx, (uint8_t *)val);
+}
+
static ALWAYS_INLINE bool marshal_qid(struct _marshal_ctx *ctx, struct lib9p_qid *val) {
return marshal_qt(ctx, &val->type)
|| marshal_4(ctx, &val->vers)
diff --git a/lib9p/types.gen b/lib9p/types.gen
index a594e45..b41442a 100755
--- a/lib9p/types.gen
+++ b/lib9p/types.gen
@@ -35,10 +35,18 @@ class Atom(enum.Enum):
class Bitfield:
name: str
bits: list[str]
- aliases: dict[str, str]
+ names: dict[str, str]
@property
- def static_size(self) -> int | None:
+ def static_size(self) -> int:
+ if len(self.bits) <= 8:
+ return 1
+ if len(self.bits) <= 16:
+ return 2
+ if len(self.bits) <= 32:
+ return 4
+ if len(self.bits) <= 64:
+ return 8
return int((len(self.bits) + 7) / 8)
@@ -209,7 +217,7 @@ def parse_file(
bf = Bitfield()
bf.name = m.group("name")
bf.bits = int(m.group("size")) * [""]
- bf.aliases = {}
+ bf.names = {}
env[bf.name] = bf
prev = bf
elif m := re.fullmatch(re_bitfieldspec_bit, line):
@@ -232,10 +240,10 @@ def parse_file(
raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
if bf.bits[bit]:
raise ValueError(f"{bf.name}: bit {bit} already assigned")
- if name in bf.aliases:
+ if name in bf.names:
raise ValueError(f"{bf.name}: name {name} already assigned")
bf.bits[bit] = name
- bf.aliases[name] = ""
+ bf.names[name] = ""
elif m := re.fullmatch(re_bitfieldspec_alias, line):
if m.group("bitfield"):
if m.group("bitfield") not in env:
@@ -252,9 +260,9 @@ def parse_file(
bf = prev
name = m.group("name")
val = m.group("val")
- if name in bf.aliases:
+ if name in bf.names:
raise ValueError(f"{bf.name}: name {name} already assigned")
- bf.aliases[name] = val
+ bf.names[name] = val
else:
raise SyntaxError(f"invalid line {repr(line)}")
if not version:
@@ -348,10 +356,17 @@ enum {idprefix}version {{
for bf in just_bitfields(typs):
ret += "\n"
ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n"
- vals = dict([
- *reversed([((k or f"_UNUSED_{v}"), f"1<<{v}") for (v, k) in enumerate(bf.bits)]),
- *[(k, v) for (k, v) in bf.aliases.items() if v],
- ])
+ vals = dict(
+ [
+ *reversed(
+ [
+ ((k or f"_UNUSED_{v}"), f"1<<{v}")
+ for (v, k) in enumerate(bf.bits)
+ ]
+ ),
+ *[(k, v) for (k, v) in bf.names.items() if k not in bf.bits],
+ ]
+ )
namewidth = max(len(name) for name in vals)
for name, val in vals.items():
ret += f"#define {idprefix.upper()}{bf.name.upper()}_{name.ljust(namewidth)} (({c_typename(idprefix, bf)})({val}))\n"
@@ -505,17 +520,17 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
#define validate_4(ctx) _validate_size_net(ctx, 4)
#define validate_8(ctx) _validate_size_net(ctx, 8)
"""
- for struct in just_structs_all(typs):
- inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN"
- argfn = used if struct.members else unused
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
ret += "\n"
- ret += f"static{inline} bool validate_{struct.name}(struct _validate_ctx *{argfn('ctx')}) {{"
- if len(struct.members) == 0:
- ret += "\n\treturn false;\n"
- ret += "}\n"
- continue
+ ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{"
- if struct.name == "d": # SPECIAL
+ if typ.name == "d": # SPECIAL
# Optimize... maybe the compiler could figure out to do
# this, but let's make it obvious.
ret += "\n"
@@ -526,7 +541,7 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
ret += "}\n"
continue
- if struct.name == "s": # SPECIAL
+ if typ.name == "s": # SPECIAL
# Add an extra nul-byte on the host, and validate UTF-8
# (also, similar optimization to "d").
ret += "\n"
@@ -542,27 +557,39 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += "}\n"
continue
- prefix0 = "\treturn "
- prefix1 = "\t || "
-
- struct_versions = struct.members[0].ver
-
- prefix = prefix0
- prev_size: int | None = None
- for member in struct.members:
- ret += f"\n{prefix}"
- if member.ver != struct_versions:
- ret += "( " + c_vercond(idprefix, member.ver) + " && "
- if member.cnt is not None:
- assert prev_size
- ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
- else:
- ret += f"validate_{member.typ.name}(ctx)"
- if member.ver != struct_versions:
- ret += " )"
- prefix = prefix1
- prev_size = member.static_size
- ret += ";\n}\n"
+ match typ:
+ case Bitfield():
+ ret += "\n"
+ # ret += "\t{c_typename(idprefix, typ)} mask = 0b" + (''.join('1' if b else '0' for b in reversed(typ.bits))) + ";\n"
+ ret += f"\treturn validate_{typ.static_size}(ctx);\n"
+ case Struct():
+ if len(typ.members) == 0:
+ ret += "\n\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ prefix0 = "\treturn "
+ prefix1 = "\t || "
+
+ struct_versions = typ.members[0].ver
+
+ prefix = prefix0
+ prev_size: int | None = None
+ for member in typ.members:
+ ret += f"\n{prefix}"
+ if member.ver != struct_versions:
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
+ if member.cnt is not None:
+ assert prev_size
+ ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
+ else:
+ ret += f"validate_{member.typ.name}(ctx)"
+ if member.ver != struct_versions:
+ ret += " )"
+ prefix = prefix1
+ prev_size = member.static_size
+ ret += ";\n"
+ ret += "}\n"
# unmarshal_* ##############################################################
ret += """
@@ -588,32 +615,42 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out)
ctx->net_offset += 8;
}
"""
- for struct in just_structs_all(typs):
- inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN"
- argfn = used if struct.members else unused
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
ret += "\n"
- ret += f"static{inline} void unmarshal_{struct.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *out) {{\n"
- ret += "\tmemset(out, 0, sizeof(*out));\n"
-
- if struct.members:
- struct_versions = struct.members[0].ver
- for member in struct.members:
- ret += "\t"
- prefix = "\t"
- if member.ver != struct_versions:
- ret += "if ( " + c_vercond(idprefix, member.ver) + " ) "
- prefix = "\t\t"
- if member.cnt:
- if member.ver != struct_versions:
- ret += "{\n"
- ret += f"{prefix}out->{member.name} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
- ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
- if member.ver != struct_versions:
- ret += "\t}\n"
- else:
- ret += f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
+ ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n"
+ match typ:
+ case Bitfield():
+ ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n"
+ case Struct():
+ ret += "\tmemset(out, 0, sizeof(*out));\n"
+
+ if typ.members:
+ struct_versions = typ.members[0].ver
+ for member in typ.members:
+ ret += "\t"
+ prefix = "\t"
+ if member.ver != struct_versions:
+ ret += "if ( " + c_vercond(idprefix, member.ver) + " ) "
+ prefix = "\t\t"
+ if member.cnt:
+ if member.ver != struct_versions:
+ ret += "{\n"
+ ret += f"{prefix}out->{member.name} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
+ ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
+ if member.ver != struct_versions:
+ ret += "\t}\n"
+ else:
+ ret += (
+ f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
+ )
ret += "}\n"
# marshal_* ################################################################
@@ -660,39 +697,49 @@ static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
return false;
}
"""
- for struct in just_structs_all(typs):
- inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN"
- argfn = used if struct.members else unused
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
ret += "\n"
- ret += f"static{inline} bool marshal_{struct.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('val')}) {{"
- if len(struct.members) == 0:
- ret += "\n\treturn false;\n"
- ret += "}\n"
- continue
-
- prefix0 = "\treturn "
- prefix1 = "\t || "
- prefix2 = "\t "
-
- struct_versions = struct.members[0].ver
- prefix = prefix0
- for member in struct.members:
- ret += f"\n{prefix}"
- if member.ver != struct_versions:
- ret += "( " + c_vercond(idprefix, member.ver) + " && "
- if member.cnt:
- ret += "({"
- ret += f"\n{prefix2}\tbool err = false;"
- ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
- ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
- ret += f"\n{prefix2}\terr;"
- ret += f"\n{prefix2}}})"
- else:
- ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
- if member.ver != struct_versions:
- ret += " )"
- prefix = prefix1
- ret += ";\n}\n"
+ ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{"
+ match typ:
+ case Bitfield():
+ ret += "\n"
+ ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n"
+ case Struct():
+ if len(typ.members) == 0:
+ ret += "\n\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ prefix0 = "\treturn "
+ prefix1 = "\t || "
+ prefix2 = "\t "
+
+ struct_versions = typ.members[0].ver
+ prefix = prefix0
+ for member in typ.members:
+ ret += f"\n{prefix}"
+ if member.ver != struct_versions:
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
+ if member.cnt:
+ ret += "({"
+ ret += f"\n{prefix2}\tbool err = false;"
+ ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
+ ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
+ ret += f"\n{prefix2}\terr;"
+ ret += f"\n{prefix2}}})"
+ else:
+ ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
+ if member.ver != struct_versions:
+ ret += " )"
+ prefix = prefix1
+ ret += ";\n"
+ ret += "}\n"
# vtables ##################################################################
ret += f"""