From a24a60232204702fe245c312edb0c2c8041b17a8 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Mon, 13 Jan 2025 23:20:13 -0700 Subject: lib9p: Rewrite the marshalers to support zero-copy for data --- lib9p/srv.c | 71 +++++++++++++++++++++++-------------------------------------- 1 file changed, 27 insertions(+), 44 deletions(-) (limited to 'lib9p/srv.c') diff --git a/lib9p/srv.c b/lib9p/srv.c index 9837994..61b40ea 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -131,8 +131,8 @@ struct _lib9p_srv_req { /* immutable */ struct _srv_sess *parent_sess; uint16_t tag; - /* mutable */ uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */ + /* mutable */ struct lib9p_srv_ctx ctx; }; @@ -140,31 +140,18 @@ struct _lib9p_srv_req { #define nonrespond_errorf errorf -static uint32_t rerror_overhead_for_version(enum lib9p_version version, - uint8_t *scratch) { - struct lib9p_ctx empty_ctx = { - .version = version, - .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, - }; - struct lib9p_msg_Rerror empty_error = { 0 }; - bool e; - - e = lib9p_Rmsg_marshal(&empty_ctx, LIB9P_TYP_Rerror, - &empty_error, /* host_body */ - scratch); /* net_bytes */ - assert(!e); - - uint32_t min_msg_size = uint32le_decode(scratch); - - /* Assert that min_msg_size + biggest_possible_MAX_ERR_SIZE - * won't overflow uint32... because using - * __builtin_add_overflow in respond_error() would be a bit - * much. */ - assert(min_msg_size < (UINT32_MAX - UINT16_MAX)); - /* Assert that min_msg_size doesn't overflow MAX_MSG_SIZE. */ - assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size); - - return min_msg_size; +static ssize_t write_Rmsg(struct _lib9p_srv_req *req, struct lib9p_Rmsg_send_buf *resp) { + ssize_t r = 0, _r; + cr_mutex_lock(&req->parent_sess->parent_conn->writelock); + for (size_t i = 0; i < resp->iov_cnt; i++) { + _r = LO_CALL(req->parent_sess->parent_conn->fd, write, + resp->iov[i].iov_base, resp->iov[i].iov_len); + if (_r < 0) + return _r; + r += _r; + } + cr_mutex_unlock(&req->parent_sess->parent_conn->writelock); + return r; } static void respond_error(struct _lib9p_srv_req *req) { @@ -186,19 +173,17 @@ static void respond_error(struct _lib9p_srv_req *req) { struct _srv_sess *sess = req->parent_sess; /* Truncate the error-string if necessary to avoid needing to - * return LINUX_ERANGE. The assert() in - * rerror_overhead_for_version() has checked that this - * addition doesn't overflow. */ + * return LINUX_ERANGE. */ if (((uint32_t)host.ename.len) + sess->rerror_overhead > sess->max_msg_size) host.ename.len = sess->max_msg_size - sess->rerror_overhead; - lib9p_Rmsg_marshal(&req->ctx.basectx, LIB9P_TYP_Rerror, - &host, req->net_bytes); + struct lib9p_Rmsg_send_buf net; - cr_mutex_lock(&sess->parent_conn->writelock); - r = LO_CALL(sess->parent_conn->fd, write, - req->net_bytes, uint32le_decode(req->net_bytes)); - cr_mutex_unlock(&sess->parent_conn->writelock); + lib9p_Rmsg_marshal(&req->ctx.basectx, + LIB9P_TYP_Rerror, &host, + &net); + + r = write_Rmsg(req, &net); if (r < 0) nonrespond_errorf("write: %s", net_strerror(-r)); } @@ -235,7 +220,7 @@ static void handle_message(struct _lib9p_srv_req *ctx); srv->readers++; - uint32_t initial_rerror_overhead = rerror_overhead_for_version(0, buf); + uint32_t initial_rerror_overhead = _lib9p_table_msg_min_size[LIB9P_VER_unknown]; for (;;) { struct _srv_conn conn = { @@ -432,14 +417,12 @@ static void handle_message(struct _lib9p_srv_req *ctx) { if (lib9p_ctx_has_error(&ctx->ctx.basectx)) respond_error(ctx); else { - if (lib9p_Rmsg_marshal(&ctx->ctx.basectx, typ+1, host_resp, - ctx->net_bytes)) + struct lib9p_Rmsg_send_buf net_resp; + if (lib9p_Rmsg_marshal(&ctx->ctx.basectx, + typ+1, host_resp, + &net_resp)) goto write; - - cr_mutex_lock(&ctx->parent_sess->parent_conn->writelock); - LO_CALL(ctx->parent_sess->parent_conn->fd, write, - ctx->net_bytes, uint32le_decode(ctx->net_bytes)); - cr_mutex_unlock(&ctx->parent_sess->parent_conn->writelock); + write_Rmsg(ctx, &net_resp); } } @@ -583,7 +566,7 @@ static void handle_Tversion(struct _lib9p_srv_req *ctx, #endif } - uint32_t min_msg_size = rerror_overhead_for_version(version, ctx->net_bytes); + uint32_t min_msg_size = _lib9p_table_msg_min_size[version]; if (req->max_msg_size < min_msg_size) { lib9p_errorf(&ctx->ctx.basectx, LINUX_EDOM, "requested max_msg_size is less than minimum for %s (%"PRIu32" < %"PRIu32")", -- cgit v1.2.3-2-g168b