From a24a60232204702fe245c312edb0c2c8041b17a8 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Mon, 13 Jan 2025 23:20:13 -0700 Subject: lib9p: Rewrite the marshalers to support zero-copy for data --- lib9p/srv.c | 71 +++++++++++++++++++++++-------------------------------------- 1 file changed, 27 insertions(+), 44 deletions(-) (limited to 'lib9p/srv.c') diff --git a/lib9p/srv.c b/lib9p/srv.c index 9837994..61b40ea 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -131,8 +131,8 @@ struct _lib9p_srv_req { /* immutable */ struct _srv_sess *parent_sess; uint16_t tag; - /* mutable */ uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */ + /* mutable */ struct lib9p_srv_ctx ctx; }; @@ -140,31 +140,18 @@ struct _lib9p_srv_req { #define nonrespond_errorf errorf -static uint32_t rerror_overhead_for_version(enum lib9p_version version, - uint8_t *scratch) { - struct lib9p_ctx empty_ctx = { - .version = version, - .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, - }; - struct lib9p_msg_Rerror empty_error = { 0 }; - bool e; - - e = lib9p_Rmsg_marshal(&empty_ctx, LIB9P_TYP_Rerror, - &empty_error, /* host_body */ - scratch); /* net_bytes */ - assert(!e); - - uint32_t min_msg_size = uint32le_decode(scratch); - - /* Assert that min_msg_size + biggest_possible_MAX_ERR_SIZE - * won't overflow uint32... because using - * __builtin_add_overflow in respond_error() would be a bit - * much. */ - assert(min_msg_size < (UINT32_MAX - UINT16_MAX)); - /* Assert that min_msg_size doesn't overflow MAX_MSG_SIZE. */ - assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size); - - return min_msg_size; +static ssize_t write_Rmsg(struct _lib9p_srv_req *req, struct lib9p_Rmsg_send_buf *resp) { + ssize_t r = 0, _r; + cr_mutex_lock(&req->parent_sess->parent_conn->writelock); + for (size_t i = 0; i < resp->iov_cnt; i++) { + _r = LO_CALL(req->parent_sess->parent_conn->fd, write, + resp->iov[i].iov_base, resp->iov[i].iov_len); + if (_r < 0) + return _r; + r += _r; + } + cr_mutex_unlock(&req->parent_sess->parent_conn->writelock); + return r; } static void respond_error(struct _lib9p_srv_req *req) { @@ -186,19 +173,17 @@ static void respond_error(struct _lib9p_srv_req *req) { struct _srv_sess *sess = req->parent_sess; /* Truncate the error-string if necessary to avoid needing to - * return LINUX_ERANGE. The assert() in - * rerror_overhead_for_version() has checked that this - * addition doesn't overflow. */ + * return LINUX_ERANGE. */ if (((uint32_t)host.ename.len) + sess->rerror_overhead > sess->max_msg_size) host.ename.len = sess->max_msg_size - sess->rerror_overhead; - lib9p_Rmsg_marshal(&req->ctx.basectx, LIB9P_TYP_Rerror, - &host, req->net_bytes); + struct lib9p_Rmsg_send_buf net; - cr_mutex_lock(&sess->parent_conn->writelock); - r = LO_CALL(sess->parent_conn->fd, write, - req->net_bytes, uint32le_decode(req->net_bytes)); - cr_mutex_unlock(&sess->parent_conn->writelock); + lib9p_Rmsg_marshal(&req->ctx.basectx, + LIB9P_TYP_Rerror, &host, + &net); + + r = write_Rmsg(req, &net); if (r < 0) nonrespond_errorf("write: %s", net_strerror(-r)); } @@ -235,7 +220,7 @@ static void handle_message(struct _lib9p_srv_req *ctx); srv->readers++; - uint32_t initial_rerror_overhead = rerror_overhead_for_version(0, buf); + uint32_t initial_rerror_overhead = _lib9p_table_msg_min_size[LIB9P_VER_unknown]; for (;;) { struct _srv_conn conn = { @@ -432,14 +417,12 @@ static void handle_message(struct _lib9p_srv_req *ctx) { if (lib9p_ctx_has_error(&ctx->ctx.basectx)) respond_error(ctx); else { - if (lib9p_Rmsg_marshal(&ctx->ctx.basectx, typ+1, host_resp, - ctx->net_bytes)) + struct lib9p_Rmsg_send_buf net_resp; + if (lib9p_Rmsg_marshal(&ctx->ctx.basectx, + typ+1, host_resp, + &net_resp)) goto write; - - cr_mutex_lock(&ctx->parent_sess->parent_conn->writelock); - LO_CALL(ctx->parent_sess->parent_conn->fd, write, - ctx->net_bytes, uint32le_decode(ctx->net_bytes)); - cr_mutex_unlock(&ctx->parent_sess->parent_conn->writelock); + write_Rmsg(ctx, &net_resp); } } @@ -583,7 +566,7 @@ static void handle_Tversion(struct _lib9p_srv_req *ctx, #endif } - uint32_t min_msg_size = rerror_overhead_for_version(version, ctx->net_bytes); + uint32_t min_msg_size = _lib9p_table_msg_min_size[version]; if (req->max_msg_size < min_msg_size) { lib9p_errorf(&ctx->ctx.basectx, LINUX_EDOM, "requested max_msg_size is less than minimum for %s (%"PRIu32" < %"PRIu32")", -- cgit v1.2.3-2-g168b From 6759a952978ea011dbc08a13b8f97a7c97572d16 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Wed, 12 Feb 2025 20:09:19 -0700 Subject: lib9p: srv: Dynamically allocate read buffers --- lib9p/srv.c | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) (limited to 'lib9p/srv.c') diff --git a/lib9p/srv.c b/lib9p/srv.c index 61b40ea..c624fa8 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -131,7 +131,7 @@ struct _lib9p_srv_req { /* immutable */ struct _srv_sess *parent_sess; uint16_t tag; - uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */ + uint8_t *net_bytes; /* mutable */ struct lib9p_srv_ctx ctx; }; @@ -212,8 +212,6 @@ static bool read_exactly(lo_interface net_stream_conn fd, uint8_t *buf, size_t g static void handle_message(struct _lib9p_srv_req *ctx); [[noreturn]] void lib9p_srv_read_cr(struct lib9p_srv *srv, lo_interface net_stream_listener listener) { - uint8_t buf[CONFIG_9P_MAX_MSG_SIZE]; - assert(srv); assert(srv->rootdir); assert(!LO_IS_NULL(listener)); @@ -248,6 +246,7 @@ static void handle_message(struct _lib9p_srv_req *ctx); nextmsg: /* Read the message. */ size_t done = 0; + uint8_t buf[7]; if (read_exactly(conn.fd, buf, 4, &done)) goto close; size_t goal = uint32le_decode(buf); @@ -277,11 +276,16 @@ static void handle_message(struct _lib9p_srv_req *ctx); respond_error(&req); goto nextmsg; } - if (read_exactly(conn.fd, buf, goal, &done)) + req.net_bytes = malloc(goal); + assert(req.net_bytes); + memcpy(req.net_bytes, buf, done); + if (read_exactly(conn.fd, req.net_bytes, goal, &done)) { + free(req.net_bytes); goto close; + } /* Handle the message... */ - if (buf[4] == LIB9P_TYP_Tversion) + if (req.net_bytes[4] == LIB9P_TYP_Tversion) /* ...in this coroutine for Tversion, */ handle_message(&req); else @@ -302,7 +306,6 @@ static void handle_message(struct _lib9p_srv_req *ctx); /* write coroutine ************************************************************/ COROUTINE lib9p_srv_write_cr(void *_srv) { - uint8_t net[CONFIG_9P_MAX_MSG_SIZE]; struct _lib9p_srv_req req; _lib9p_srv_reqch_req_t rpc_handle; @@ -321,11 +324,9 @@ COROUTINE lib9p_srv_write_cr(void *_srv) { _lib9p_srv_reqch_send_resp(rpc_handle, 0); cr_exit(); } - /* Deep-copy the request from the reader coroutine's + /* Copy the request from the reader coroutine's * stack to our stack. */ req = *rpc_handle.req; - memcpy(net, req.net_bytes, uint32le_decode(req.net_bytes)); - req.net_bytes = net; /* Record that we have it. */ reqmap_store(&req.parent_sess->reqs, req.tag, &req); /* Notify the reader coroutine that we're done with @@ -393,19 +394,15 @@ static tmessage_handler tmessage_handlers[0x100] = { }; static void handle_message(struct _lib9p_srv_req *ctx) { - uint8_t host_req[CONFIG_9P_MAX_HOSTMSG_SIZE]; + uint8_t *host_req = NULL; uint8_t host_resp[CONFIG_9P_MAX_HOSTMSG_SIZE]; /* Unmarshal it. */ ssize_t host_size = lib9p_Tmsg_validate(&ctx->ctx.basectx, ctx->net_bytes); if (host_size < 0) goto write; - if ((size_t)host_size > sizeof(host_req)) { - lib9p_errorf(&ctx->ctx.basectx, - LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)", - host_size, sizeof(host_req)); - goto write; - } + host_req = malloc(host_size); + assert(host_req); enum lib9p_msg_type typ; lib9p_Tmsg_unmarshal(&ctx->ctx.basectx, ctx->net_bytes, &typ, host_req); @@ -424,6 +421,9 @@ static void handle_message(struct _lib9p_srv_req *ctx) { goto write; write_Rmsg(ctx, &net_resp); } + if (host_req) + free(host_req); + free(ctx->net_bytes); } #define util_handler_common(ctx, req, resp) do { \ -- cgit v1.2.3-2-g168b