summaryrefslogtreecommitdiff
path: root/lib9p
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-28 07:43:55 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-28 07:43:55 -0600
commitc8c10a3a7a89821202086392accaf4dd8c4a78e3 (patch)
treedbd5c61f3dd313fd55577fde835e64aede03de47 /lib9p
parent7ec97df3ee8edfd102fe573eaa61cf4e5c6284cb (diff)
wip
Diffstat (limited to 'lib9p')
-rw-r--r--lib9p/include/lib9p/9p.h5
-rw-r--r--lib9p/srv.c70
-rw-r--r--lib9p/types.c29
-rwxr-xr-xlib9p/types.gen60
4 files changed, 82 insertions, 82 deletions
diff --git a/lib9p/include/lib9p/9p.h b/lib9p/include/lib9p/9p.h
index 954498e..b33ae96 100644
--- a/lib9p/include/lib9p/9p.h
+++ b/lib9p/include/lib9p/9p.h
@@ -38,6 +38,9 @@ int lib9p_errorf(struct lib9p_ctx *ctx, uint32_t linux_errno, char const *fmt, .
* @param net_bytes : the complete request, starting with the "size[4]"
*
* @return required size, or -1 on error
+ *
+ * @errno LINUX_EBADMSG: message is too short for content
+ * @errno LINUX_EBADMSG: message contains invalid UTF-8
*/
ssize_t lib9p_unmarshal_size(struct lib9p_ctx *ctx, uint8_t *net_bytes);
@@ -67,6 +70,8 @@ void lib9p_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes,
*
* @return ret_bytes : the buffer to encode to, must be at be at least lib9p_ctx_max_msg_size(ctx) bytes
* @return whether there was an error (false=success, true=error)
+ *
+ * @errno LINUX_ERANGE: reply does not fit in ctx->max_msg_size
*/
bool lib9p_marshal(struct lib9p_ctx *ctx, enum lib9p_msg_type typ, uint16_t tag, void *body,
uint8_t *ret_bytes);
diff --git a/lib9p/srv.c b/lib9p/srv.c
index 9a68a2b..039b4c2 100644
--- a/lib9p/srv.c
+++ b/lib9p/srv.c
@@ -16,8 +16,7 @@ struct lib9p_srvconn {
cid_t reader;
int fd;
/* mutable */
- uint32_t max_msg_size;
- enum lib9p_version version;
+ struct lib9p_ctx ctx;
unsigned int refcount;
};
@@ -26,6 +25,19 @@ struct lib9p_srvreq {
uint8_t *msg;
};
+static void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) {
+ struct lib9p_msg_Rerror host = {
+ .ename = {
+ .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE),
+ .utf8 = (uint8_t*)ctx->err_msg,
+ },
+ .errno = ctx->err_num,
+ };
+ lib9p_marshal(ctx, LIB9P_TYP_Rerror, tag, &host, net);
+}
+
+void handle_message(struct lib9p_srvconn *conn, uint8_t *net);
+
COROUTINE lib9p_srv_read_cr(void *_srv) {
uint8_t buf[CONFIG_9P_MAX_MSG_SIZE];
@@ -37,9 +49,11 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
struct lib9p_srvconn conn = {
.srv = srv,
.reader = cr_getcid(),
-
- .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
- .version = LIB9P_VER_UNINITIALIZED,
+
+ .ctx = {
+ .version = LIB9P_VER_UNINITIALIZED,
+ .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
+ },
.refcount = 1,
};
conn.fd = netio_accept(srv->sockfd);
@@ -54,7 +68,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
while (done < goal) {
ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done);
if (r < 0) {
- fprintf(stderr, "error: read: %m", -r);
+ fprintf(stderr, "error: read: %s", strerror(-r));
goto close;
} else if (r == 0) {
if (done != 0)
@@ -69,16 +83,11 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
fprintf(stderr, "error: T-message is impossibly small");
goto close;
}
- if (goal > conn.max_msg_size) {
- struct lib9p_ctx ctx = {
- .version = conn.version,
- .max_msg_size = conn.max_msg_size,
- };
- if (initialized)
- lib9p_errorf(&ctx, LINUX_EMSGSIZE, "T-message larger than negotiated limit (%zu > %zu)", goal, conn.max_msg_size);
- else
- lib9p_errorf(&ctx, LINUX_EMSGSIZE, "T-message larger than server limit (%zu > %zu)", goal, conn.max_msg_size);
- marshal_error(&ctx, buf);
+ if (goal > conn.ctx.max_msg_size) {
+ lib9p_errorf(&conn.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %zu)",
+ conn.ctx.version ? "negotiated" : "server", goal, conn.ctx.max_msg_size);
+ uint16_t tag = decode_u16le(&buf[5]);
+ marshal_error(&conn.ctx, tag, buf);
netio_write(conn.fd, buf, decode_u32le(buf));
continue;
}
@@ -86,7 +95,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
while (done < goal) {
ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done);
if (r < 0) {
- fprintf(stderr, "error: read: %m", -r);
+ fprintf(stderr, "error: read: %s", strerror(-r));
goto close;
} else if (r == 0) {
fprintf(stderr, "error: read: unexpected EOF");
@@ -96,7 +105,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
}
/* Handle the message... */
- if (conn.version == LIB9P_VER_UNINITIALIZED) {
+ if (conn.ctx.version == LIB9P_VER_UNINITIALIZED) {
/* ...synchronously if we haven't negotiated the protocol yet, ... */
handle_message(&conn, buf);
} else {
@@ -140,7 +149,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) {
cr_end();
}
-void handle_message(lib9p_srvconn *conn, uint8_t *net) {
+void handle_message(struct lib9p_srvconn *conn, uint8_t *net) {
uint8_t host[CONFIG_9P_MAX_MSG_SIZE];
struct lib9p_ctx ctx = {
@@ -175,24 +184,5 @@ void handle_message(lib9p_srvconn *conn, uint8_t *net) {
netio_write(req.conn->fd, net, decode_u32le(net));
}
-static inline uint16_t min_u16(uint16_t a, b) {
- return (a < b) ? a : b;
-}
-
-/* We have special code for marshaling Rerror because we don't ever
- * want to produce an error because the err_msg is too long for the
- * `ctx->max_msg_size`! */
-void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) {
- struct lib9p_msg_Rerror host = {
- .ename = {
- .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE),
- .utf8 = ctx->err_msg,
- },
- };
- if (host.ename.len + ctx->Rerror_overhead > ctx->max_msg_size)
- host.ename.len = ctx->max_msg_size - overhead;
- lib9p_marshal(ctx, tag, host, net);
-}
-
-ERANGE for reply too large
-EPROTONOSUPPORT for version errors
+// EMSGSIZE for request too large
+// EPROTONOSUPPORT for version errors
diff --git a/lib9p/types.c b/lib9p/types.c
index 5b6d0a0..071e3c8 100644
--- a/lib9p/types.c
+++ b/lib9p/types.c
@@ -83,10 +83,10 @@ static inline bool checksize_stat(struct _checksize_ctx *ctx) {
|| checksize_s(ctx)
|| checksize_s(ctx)
|| checksize_s(ctx)
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_s(ctx) )
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) )
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) )
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) );
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_s(ctx) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) );
}
static bool checksize_Tversion(struct _checksize_ctx *ctx) {
@@ -103,7 +103,7 @@ static bool checksize_Tauth(struct _checksize_ctx *ctx) {
return checksize_4(ctx)
|| checksize_s(ctx)
|| checksize_s(ctx)
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) );
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) );
}
static bool checksize_Rauth(struct _checksize_ctx *ctx) {
@@ -123,7 +123,7 @@ static bool checksize_Rattach(struct _checksize_ctx *ctx) {
static bool checksize_Rerror(struct _checksize_ctx *ctx) {
return checksize_s(ctx)
- || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) );
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) );
}
static bool checksize_Tflush(struct _checksize_ctx *ctx) {
@@ -521,7 +521,7 @@ static void unmarshal_Rswrite(struct _unmarshal_ctx *ctx, struct lib9p_msg_Rswri
/* marshal_* ******************************************************************/
static inline bool _marshal_too_large(struct _marshal_ctx *ctx) {
- lib9p_errorf(ctx->ctx, LINUX_EMSGSIZE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
+ lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
ctx->ctx->max_msg_size);
@@ -599,10 +599,10 @@ static inline bool marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val
|| marshal_s(ctx, &val->file_owner_uid)
|| marshal_s(ctx, &val->file_owner_gid)
|| marshal_s(ctx, &val->file_last_modified_uid)
- || marshal_s(ctx, &val->file_extension)
- || marshal_4(ctx, &val->file_owner_n_uid)
- || marshal_4(ctx, &val->file_owner_n_gid)
- || marshal_4(ctx, &val->file_last_modified_n_uid);
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_s(ctx, &val->file_extension) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_uid) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_gid) )
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_last_modified_n_uid) );
}
static bool marshal_Tversion(struct _marshal_ctx *ctx, struct lib9p_msg_Tversion *val) {
@@ -619,7 +619,7 @@ static bool marshal_Tauth(struct _marshal_ctx *ctx, struct lib9p_msg_Tauth *val)
return marshal_4(ctx, &val->afid)
|| marshal_s(ctx, &val->uname)
|| marshal_s(ctx, &val->aname)
- || marshal_4(ctx, &val->n_uname);
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->n_uname) );
}
static bool marshal_Rauth(struct _marshal_ctx *ctx, struct lib9p_msg_Rauth *val) {
@@ -638,8 +638,11 @@ static bool marshal_Rattach(struct _marshal_ctx *ctx, struct lib9p_msg_Rattach *
}
static bool marshal_Rerror(struct _marshal_ctx *ctx, struct lib9p_msg_Rerror *val) {
+ /* Truncate the error-string if necessary to avoid returning ERANGE. */
+ if (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size)
+ val->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead;
return marshal_s(ctx, &val->ename)
- || marshal_4(ctx, &val->errno);
+ || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->errno) );
}
static bool marshal_Tflush(struct _marshal_ctx *ctx, struct lib9p_msg_Tflush *val) {
diff --git a/lib9p/types.gen b/lib9p/types.gen
index 48f8107..44eebbe 100755
--- a/lib9p/types.gen
+++ b/lib9p/types.gen
@@ -205,14 +205,22 @@ def c_typename(idprefix: str, typ: Atom | Struct) -> str:
raise ValueError(f"not a type: {typ.__class__.__name__}")
+def c_verenum(idprefix: str, ver: str) -> str:
+ return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
+
+
def c_vercomment(versions: set[str]) -> str | None:
if "9P2000" in versions:
return None
return "/* " + (", ".join(sorted(versions))) + " */"
-def c_ver(idprefix: str, ver: str) -> str:
- return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
+def c_vercond(idprefix: str, versions: set[str]) -> str:
+ if len(versions) == 1:
+ return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})"
+ return (
+ "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )"
+ )
def gen_h(idprefix: str, versions: set[str], structs: list[Struct]) -> str:
@@ -233,7 +241,7 @@ enum {idprefix}version {{
"""
verwidth = max(len(v) for v in versions)
for ver in sorted(versions):
- ret += f"\t{c_ver(idprefix, ver)},"
+ ret += f"\t{c_verenum(idprefix, ver)},"
ret += (" " * (verwidth - len(ver))) + ' /* "' + ver + '" */\n'
ret += f"\t{idprefix.upper()}VER_NUM,\n"
ret += "};\n"
@@ -396,7 +404,9 @@ static inline bool _checksize_list(struct _checksize_ctx *ctx,
ret += "\tif (checksize_4(ctx))\n"
ret += "\t\treturn true;\n"
ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
- ret += "\treturn _checksize_net(ctx, len) || _checksize_host(ctx, len);\n"
+ ret += (
+ "\treturn _checksize_net(ctx, len) || _checksize_host(ctx, len);\n"
+ )
ret += "}\n"
case "s":
# Add an extra nul-byte on the host, and validate
@@ -414,22 +424,12 @@ static inline bool _checksize_list(struct _checksize_ctx *ctx,
ret += "}\n"
case _:
struct_versions = struct.members[0].ver
-
prefix = prefix0
prev_size: int | None = None
for member in struct.members:
ret += f"\n{prefix}"
if member.ver != struct_versions:
- ret += (
- "( ( "
- + (
- " || ".join(
- f"(ctx->ctx->version=={c_ver(idprefix, v)})"
- for v in sorted(member.ver)
- )
- )
- + " ) && "
- )
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
if member.cnt is not None:
assert prev_size
ret += f"_checksize_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), checksize_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
@@ -478,16 +478,7 @@ static inline void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
ret += "\t"
prefix = "\t"
if member.ver != struct_versions:
- ret += (
- "if ( "
- + (
- " || ".join(
- f"(ctx->ctx->version=={c_ver(idprefix, v)})"
- for v in sorted(member.ver)
- )
- )
- + " ) "
- )
+ ret += "if ( " + c_vercond(idprefix, member.ver) + " ) "
prefix = "\t\t"
if member.cnt:
if member.ver != struct_versions:
@@ -507,7 +498,7 @@ static inline void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
/* marshal_* ******************************************************************/
static inline bool _marshal_too_large(struct _marshal_ctx *ctx) {
- lib9p_errorf(ctx->ctx, LINUX_EMSGSIZE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
+ lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
ctx->ctx->max_msg_size);
@@ -556,21 +547,32 @@ static inline bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
ret += "}\n"
continue
+ if struct.name == "Rerror":
+ ret += "\n\t/* Truncate the error-string if necessary to avoid returning ERANGE. */"
+ ret += "\n\tif (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size)"
+ ret += "\n\t\tval->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead;"
+
prefix0 = "\treturn "
prefix1 = "\t || "
prefix2 = "\t "
+ struct_versions = struct.members[0].ver
prefix = prefix0
for member in struct.members:
+ ret += f"\n{prefix}"
+ if member.ver != struct_versions:
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
if member.cnt:
- ret += f"\n{prefix }({{"
+ ret += "({"
ret += f"\n{prefix2}\tbool err = false;"
ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
ret += f"\n{prefix2}\terr;"
ret += f"\n{prefix2}}})"
else:
- ret += f"\n{prefix}marshal_{member.typ.name}(ctx, &val->{member.name})"
+ ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
+ if member.ver != struct_versions:
+ ret += " )"
prefix = prefix1
ret += ";\n}\n"
@@ -595,7 +597,7 @@ struct _vtable_version _{idprefix}vtables[LIB9P_VER_NUM] = {{
}}}},
"""
for ver in sorted(versions):
- ret += f"\t[{c_ver(idprefix, ver)}] = {{ .msgs = {{\n"
+ ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n"
for msg in structs:
if ver not in msg.msgver:
continue