diff options
-rw-r--r-- | lib9p/include/lib9p/9p.h | 5 | ||||
-rw-r--r-- | lib9p/srv.c | 70 | ||||
-rw-r--r-- | lib9p/types.c | 29 | ||||
-rwxr-xr-x | lib9p/types.gen | 60 |
4 files changed, 82 insertions, 82 deletions
diff --git a/lib9p/include/lib9p/9p.h b/lib9p/include/lib9p/9p.h index 954498e..b33ae96 100644 --- a/lib9p/include/lib9p/9p.h +++ b/lib9p/include/lib9p/9p.h @@ -38,6 +38,9 @@ int lib9p_errorf(struct lib9p_ctx *ctx, uint32_t linux_errno, char const *fmt, . * @param net_bytes : the complete request, starting with the "size[4]" * * @return required size, or -1 on error + * + * @errno LINUX_EBADMSG: message is too short for content + * @errno LINUX_EBADMSG: message contains invalid UTF-8 */ ssize_t lib9p_unmarshal_size(struct lib9p_ctx *ctx, uint8_t *net_bytes); @@ -67,6 +70,8 @@ void lib9p_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes, * * @return ret_bytes : the buffer to encode to, must be at be at least lib9p_ctx_max_msg_size(ctx) bytes * @return whether there was an error (false=success, true=error) + * + * @errno LINUX_ERANGE: reply does not fit in ctx->max_msg_size */ bool lib9p_marshal(struct lib9p_ctx *ctx, enum lib9p_msg_type typ, uint16_t tag, void *body, uint8_t *ret_bytes); diff --git a/lib9p/srv.c b/lib9p/srv.c index 9a68a2b..039b4c2 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -16,8 +16,7 @@ struct lib9p_srvconn { cid_t reader; int fd; /* mutable */ - uint32_t max_msg_size; - enum lib9p_version version; + struct lib9p_ctx ctx; unsigned int refcount; }; @@ -26,6 +25,19 @@ struct lib9p_srvreq { uint8_t *msg; }; +static void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) { + struct lib9p_msg_Rerror host = { + .ename = { + .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE), + .utf8 = (uint8_t*)ctx->err_msg, + }, + .errno = ctx->err_num, + }; + lib9p_marshal(ctx, LIB9P_TYP_Rerror, tag, &host, net); +} + +void handle_message(struct lib9p_srvconn *conn, uint8_t *net); + COROUTINE lib9p_srv_read_cr(void *_srv) { uint8_t buf[CONFIG_9P_MAX_MSG_SIZE]; @@ -37,9 +49,11 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { struct lib9p_srvconn conn = { .srv = srv, .reader = cr_getcid(), - - .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, - .version = LIB9P_VER_UNINITIALIZED, + + .ctx = { + .version = LIB9P_VER_UNINITIALIZED, + .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, + }, .refcount = 1, }; conn.fd = netio_accept(srv->sockfd); @@ -54,7 +68,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { while (done < goal) { ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done); if (r < 0) { - fprintf(stderr, "error: read: %m", -r); + fprintf(stderr, "error: read: %s", strerror(-r)); goto close; } else if (r == 0) { if (done != 0) @@ -69,16 +83,11 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { fprintf(stderr, "error: T-message is impossibly small"); goto close; } - if (goal > conn.max_msg_size) { - struct lib9p_ctx ctx = { - .version = conn.version, - .max_msg_size = conn.max_msg_size, - }; - if (initialized) - lib9p_errorf(&ctx, LINUX_EMSGSIZE, "T-message larger than negotiated limit (%zu > %zu)", goal, conn.max_msg_size); - else - lib9p_errorf(&ctx, LINUX_EMSGSIZE, "T-message larger than server limit (%zu > %zu)", goal, conn.max_msg_size); - marshal_error(&ctx, buf); + if (goal > conn.ctx.max_msg_size) { + lib9p_errorf(&conn.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %zu)", + conn.ctx.version ? "negotiated" : "server", goal, conn.ctx.max_msg_size); + uint16_t tag = decode_u16le(&buf[5]); + marshal_error(&conn.ctx, tag, buf); netio_write(conn.fd, buf, decode_u32le(buf)); continue; } @@ -86,7 +95,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { while (done < goal) { ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done); if (r < 0) { - fprintf(stderr, "error: read: %m", -r); + fprintf(stderr, "error: read: %s", strerror(-r)); goto close; } else if (r == 0) { fprintf(stderr, "error: read: unexpected EOF"); @@ -96,7 +105,7 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { } /* Handle the message... */ - if (conn.version == LIB9P_VER_UNINITIALIZED) { + if (conn.ctx.version == LIB9P_VER_UNINITIALIZED) { /* ...synchronously if we haven't negotiated the protocol yet, ... */ handle_message(&conn, buf); } else { @@ -140,7 +149,7 @@ COROUTINE lib9p_srv_write_cr(void *_srv) { cr_end(); } -void handle_message(lib9p_srvconn *conn, uint8_t *net) { +void handle_message(struct lib9p_srvconn *conn, uint8_t *net) { uint8_t host[CONFIG_9P_MAX_MSG_SIZE]; struct lib9p_ctx ctx = { @@ -175,24 +184,5 @@ void handle_message(lib9p_srvconn *conn, uint8_t *net) { netio_write(req.conn->fd, net, decode_u32le(net)); } -static inline uint16_t min_u16(uint16_t a, b) { - return (a < b) ? a : b; -} - -/* We have special code for marshaling Rerror because we don't ever - * want to produce an error because the err_msg is too long for the - * `ctx->max_msg_size`! */ -void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) { - struct lib9p_msg_Rerror host = { - .ename = { - .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE), - .utf8 = ctx->err_msg, - }, - }; - if (host.ename.len + ctx->Rerror_overhead > ctx->max_msg_size) - host.ename.len = ctx->max_msg_size - overhead; - lib9p_marshal(ctx, tag, host, net); -} - -ERANGE for reply too large -EPROTONOSUPPORT for version errors +// EMSGSIZE for request too large +// EPROTONOSUPPORT for version errors diff --git a/lib9p/types.c b/lib9p/types.c index 5b6d0a0..071e3c8 100644 --- a/lib9p/types.c +++ b/lib9p/types.c @@ -83,10 +83,10 @@ static inline bool checksize_stat(struct _checksize_ctx *ctx) { || checksize_s(ctx) || checksize_s(ctx) || checksize_s(ctx) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_s(ctx) ) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) ) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) ) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) ); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_s(ctx) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) ); } static bool checksize_Tversion(struct _checksize_ctx *ctx) { @@ -103,7 +103,7 @@ static bool checksize_Tauth(struct _checksize_ctx *ctx) { return checksize_4(ctx) || checksize_s(ctx) || checksize_s(ctx) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) ); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) ); } static bool checksize_Rauth(struct _checksize_ctx *ctx) { @@ -123,7 +123,7 @@ static bool checksize_Rattach(struct _checksize_ctx *ctx) { static bool checksize_Rerror(struct _checksize_ctx *ctx) { return checksize_s(ctx) - || ( ( (ctx->ctx->version==LIB9P_VER_9P2000_u) ) && checksize_4(ctx) ); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && checksize_4(ctx) ); } static bool checksize_Tflush(struct _checksize_ctx *ctx) { @@ -521,7 +521,7 @@ static void unmarshal_Rswrite(struct _unmarshal_ctx *ctx, struct lib9p_msg_Rswri /* marshal_* ******************************************************************/ static inline bool _marshal_too_large(struct _marshal_ctx *ctx) { - lib9p_errorf(ctx->ctx, LINUX_EMSGSIZE, "%s too large to marshal into %s limit (limit=%"PRIu32")", + lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), ctx->ctx->max_msg_size); @@ -599,10 +599,10 @@ static inline bool marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val || marshal_s(ctx, &val->file_owner_uid) || marshal_s(ctx, &val->file_owner_gid) || marshal_s(ctx, &val->file_last_modified_uid) - || marshal_s(ctx, &val->file_extension) - || marshal_4(ctx, &val->file_owner_n_uid) - || marshal_4(ctx, &val->file_owner_n_gid) - || marshal_4(ctx, &val->file_last_modified_n_uid); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_s(ctx, &val->file_extension) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_uid) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_owner_n_gid) ) + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->file_last_modified_n_uid) ); } static bool marshal_Tversion(struct _marshal_ctx *ctx, struct lib9p_msg_Tversion *val) { @@ -619,7 +619,7 @@ static bool marshal_Tauth(struct _marshal_ctx *ctx, struct lib9p_msg_Tauth *val) return marshal_4(ctx, &val->afid) || marshal_s(ctx, &val->uname) || marshal_s(ctx, &val->aname) - || marshal_4(ctx, &val->n_uname); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->n_uname) ); } static bool marshal_Rauth(struct _marshal_ctx *ctx, struct lib9p_msg_Rauth *val) { @@ -638,8 +638,11 @@ static bool marshal_Rattach(struct _marshal_ctx *ctx, struct lib9p_msg_Rattach * } static bool marshal_Rerror(struct _marshal_ctx *ctx, struct lib9p_msg_Rerror *val) { + /* Truncate the error-string if necessary to avoid returning ERANGE. */ + if (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size) + val->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead; return marshal_s(ctx, &val->ename) - || marshal_4(ctx, &val->errno); + || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->errno) ); } static bool marshal_Tflush(struct _marshal_ctx *ctx, struct lib9p_msg_Tflush *val) { diff --git a/lib9p/types.gen b/lib9p/types.gen index 48f8107..44eebbe 100755 --- a/lib9p/types.gen +++ b/lib9p/types.gen @@ -205,14 +205,22 @@ def c_typename(idprefix: str, typ: Atom | Struct) -> str: raise ValueError(f"not a type: {typ.__class__.__name__}") +def c_verenum(idprefix: str, ver: str) -> str: + return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" + + def c_vercomment(versions: set[str]) -> str | None: if "9P2000" in versions: return None return "/* " + (", ".join(sorted(versions))) + " */" -def c_ver(idprefix: str, ver: str) -> str: - return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" +def c_vercond(idprefix: str, versions: set[str]) -> str: + if len(versions) == 1: + return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" + return ( + "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" + ) def gen_h(idprefix: str, versions: set[str], structs: list[Struct]) -> str: @@ -233,7 +241,7 @@ enum {idprefix}version {{ """ verwidth = max(len(v) for v in versions) for ver in sorted(versions): - ret += f"\t{c_ver(idprefix, ver)}," + ret += f"\t{c_verenum(idprefix, ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver + '" */\n' ret += f"\t{idprefix.upper()}VER_NUM,\n" ret += "};\n" @@ -396,7 +404,9 @@ static inline bool _checksize_list(struct _checksize_ctx *ctx, ret += "\tif (checksize_4(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" - ret += "\treturn _checksize_net(ctx, len) || _checksize_host(ctx, len);\n" + ret += ( + "\treturn _checksize_net(ctx, len) || _checksize_host(ctx, len);\n" + ) ret += "}\n" case "s": # Add an extra nul-byte on the host, and validate @@ -414,22 +424,12 @@ static inline bool _checksize_list(struct _checksize_ctx *ctx, ret += "}\n" case _: struct_versions = struct.members[0].ver - prefix = prefix0 prev_size: int | None = None for member in struct.members: ret += f"\n{prefix}" if member.ver != struct_versions: - ret += ( - "( ( " - + ( - " || ".join( - f"(ctx->ctx->version=={c_ver(idprefix, v)})" - for v in sorted(member.ver) - ) - ) - + " ) && " - ) + ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt is not None: assert prev_size ret += f"_checksize_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), checksize_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" @@ -478,16 +478,7 @@ static inline void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { ret += "\t" prefix = "\t" if member.ver != struct_versions: - ret += ( - "if ( " - + ( - " || ".join( - f"(ctx->ctx->version=={c_ver(idprefix, v)})" - for v in sorted(member.ver) - ) - ) - + " ) " - ) + ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " prefix = "\t\t" if member.cnt: if member.ver != struct_versions: @@ -507,7 +498,7 @@ static inline void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { /* marshal_* ******************************************************************/ static inline bool _marshal_too_large(struct _marshal_ctx *ctx) { - lib9p_errorf(ctx->ctx, LINUX_EMSGSIZE, "%s too large to marshal into %s limit (limit=%"PRIu32")", + lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), ctx->ctx->max_msg_size); @@ -556,21 +547,32 @@ static inline bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += "}\n" continue + if struct.name == "Rerror": + ret += "\n\t/* Truncate the error-string if necessary to avoid returning ERANGE. */" + ret += "\n\tif (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size)" + ret += "\n\t\tval->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead;" + prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " + struct_versions = struct.members[0].ver prefix = prefix0 for member in struct.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt: - ret += f"\n{prefix }({{" + ret += "({" ret += f"\n{prefix2}\tbool err = false;" ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" else: - ret += f"\n{prefix}marshal_{member.typ.name}(ctx, &val->{member.name})" + ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.ver != struct_versions: + ret += " )" prefix = prefix1 ret += ";\n}\n" @@ -595,7 +597,7 @@ struct _vtable_version _{idprefix}vtables[LIB9P_VER_NUM] = {{ }}}}, """ for ver in sorted(versions): - ret += f"\t[{c_ver(idprefix, ver)}] = {{ .msgs = {{\n" + ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" for msg in structs: if ver not in msg.msgver: continue |