diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-28 23:23:49 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-28 23:23:49 -0600 |
commit | a5ff9ff765d3e14d01099fade6b94624bb8de22b (patch) | |
tree | 94b7dc9a194393e4ca0265d3e5ac56d3e9335a57 /lib9p | |
parent | f410026b7bc96dbb42fec3839dc5d2e41b12f4a4 (diff) |
it compiles!
Diffstat (limited to 'lib9p')
-rw-r--r-- | lib9p/9p.c | 3 | ||||
-rw-r--r-- | lib9p/include/lib9p/srv.h | 5 | ||||
-rw-r--r-- | lib9p/internal.h | 4 | ||||
-rw-r--r-- | lib9p/srv.c | 368 | ||||
-rw-r--r-- | lib9p/types.c | 5 | ||||
-rwxr-xr-x | lib9p/types.gen | 7 |
6 files changed, 244 insertions, 148 deletions
@@ -62,7 +62,8 @@ ssize_t lib9p_unmarshal_size(struct lib9p_ctx *ctx, uint8_t *net_bytes) { /* Body */ struct _vtable_msg vtable = _lib9p_vtables[ctx->version].msgs[typ]; if (!vtable.unmarshal_extrasize) - return lib9p_errorf(ctx, LINUX_EOPNOTSUPP, "unknown message type %"PRIu8, typ); + return lib9p_errorf(ctx, LINUX_EOPNOTSUPP, "unknown message type %s", + lib9p_msg_type_str(typ)); if (vtable.unmarshal_extrasize(&subctx)) return -1; diff --git a/lib9p/include/lib9p/srv.h b/lib9p/include/lib9p/srv.h index 3ca8b4d..9220bd9 100644 --- a/lib9p/include/lib9p/srv.h +++ b/lib9p/include/lib9p/srv.h @@ -4,12 +4,11 @@ #include <libcr/coroutine.h> #include <libcr_ipc/chan.h> -struct lib9p_srv_req; +struct lib9p_req; struct lib9p_srv { int sockfd; - - cr_chan_t(struct lib9p_srv_req *) reqch; + cr_chan_t(struct lib9p_req *) reqch; }; /** diff --git a/lib9p/internal.h b/lib9p/internal.h index 08bb462..f67735b 100644 --- a/lib9p/internal.h +++ b/lib9p/internal.h @@ -14,9 +14,11 @@ #include <lib9p/9p.h> #define USE_CONFIG_9P +#define USE_CONFIG_COROUTINE #include "config.h" static_assert(CONFIG_9P_MAX_ERR_SIZE <= UINT16_MAX); static_assert(CONFIG_9P_MAX_MSG_SIZE <= SSIZE_MAX); +static_assert(CONFIG_9P_MAX_ERR_SIZE + CONFIG_9P_MAX_MSG_SIZE + 2*CONFIG_9P_MAX_HOSTMSG_SIZE < CONFIG_COROUTINE_DEFAULT_STACK_SIZE); /* C language *****************************************************************/ @@ -31,8 +33,6 @@ struct lib9p_ctx { /* negotiated */ enum lib9p_version version; uint32_t max_msg_size; - /* negotiated (server) */ - uint32_t Rerror_overhead; /* state */ uint32_t err_num; diff --git a/lib9p/srv.c b/lib9p/srv.c index 74c2014..997d137 100644 --- a/lib9p/srv.c +++ b/lib9p/srv.c @@ -1,48 +1,119 @@ #include <assert.h> -#include <stdio.h> /* for fprintf(), stderr */ -#include <string.h> /* for strerror() */ +#include <inttypes.h> /* for PRI* */ +#include <stdio.h> /* for fprintf(), stderr */ +#include <string.h> /* for strerror() */ #include <libcr/coroutine.h> #include <libcr_ipc/chan.h> +#include <libcr_ipc/mutex.h> #include <libnetio/netio.h> #include <lib9p/9p.h> #include <lib9p/srv.h> #include "internal.h" +/* structs ********************************************************************/ + /* The hierarchy of concepts is: * * server -> connection -> session -> request * */ -struct lib9p_srv_sess { +/* struct lib9p_srv {} is defined in <lib9p/srv.h> */ + +struct lib9p_conn { /* immutable */ - struct lib9p_srv *srv; - int connfd; - cid_t reader; /* the lib9p_srv_read_cr() coroutine for this session */ + struct lib9p_srv *parent_srv; + int fd; + cid_t reader; /* the lib9p_srv_read_cr() coroutine */ /* mutable */ - struct lib9p_ctx ctx; + cr_mutex_t writelock; +}; + +struct lib9p_sess { + /* immutable */ + struct lib9p_conn *parent_conn; + enum lib9p_version version; + uint32_t max_msg_size; + uint32_t rerror_overhead; + /* mutable */ + bool initialized; unsigned int refcount; }; -struct lib9p_srv_req { - struct lib9p_srv_sess *sess; - uint8_t *msg; +struct lib9p_req { + /* immutable */ + struct lib9p_sess *parent_sess; + uint16_t tag; + /* mutable */ + uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */ + struct lib9p_ctx ctx; }; -static void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) { +/* base utilities *************************************************************/ + +#define nonrespond_errorf(...) fprintf(stderr, "error: " __VA_ARGS__) + +static uint32_t rerror_overhead_for_version(enum lib9p_version version, uint8_t *scratch) { + struct lib9p_ctx empty_ctx = { + .version = version, + .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, + }; + struct lib9p_msg_Rerror empty_error = { 0 }; + + assert(!lib9p_marshal(&empty_ctx, LIB9P_TYP_Rerror, 0, &empty_error, scratch)); + uint32_t min_msg_size = decode_u32le(scratch); + assert(min_msg_size < (UINT32_MAX - UINT16_MAX)); + assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size); + + return min_msg_size; +} + +static void respond_error(struct lib9p_req *req) { + assert(req->ctx.err_num); + assert(req->ctx.err_msg[0]); + + ssize_t r; struct lib9p_msg_Rerror host = { .ename = { - .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE), - .utf8 = (uint8_t*)ctx->err_msg, + .len = strnlen(req->ctx.err_msg, CONFIG_9P_MAX_ERR_SIZE), + .utf8 = (uint8_t*)req->ctx.err_msg, }, - .errno = ctx->err_num, + .errno = req->ctx.err_num, }; - lib9p_marshal(ctx, LIB9P_TYP_Rerror, tag, &host, net); + + /* Truncate the error-string if necessary to avoid needing to return ERANGE. */ + if (((uint32_t)host.ename.len) + req->parent_sess->rerror_overhead > req->parent_sess->max_msg_size) + host.ename.len = req->parent_sess->max_msg_size - req->parent_sess->rerror_overhead; + + lib9p_marshal(&req->ctx, LIB9P_TYP_Rerror, req->tag, &host, req->net_bytes); + + cr_mutex_lock(&req->parent_sess->parent_conn->writelock); + r = netio_write(req->parent_sess->parent_conn->fd, req->net_bytes, decode_u32le(req->net_bytes)); + cr_mutex_unlock(&req->parent_sess->parent_conn->writelock); + if (r < 0) + nonrespond_errorf("write: %s", strerror(-r)); } -void handle_message(struct lib9p_srvconn *conn, uint8_t *net); +/* read coroutine *************************************************************/ + +static bool read_at_least(int fd, uint8_t *buf, size_t goal, size_t *done) { + while (*done < goal) { + ssize_t r = netio_read(fd, &buf[*done], CONFIG_9P_MAX_MSG_SIZE - *done); + if (r < 0) { + nonrespond_errorf("read: %s", strerror(-r)); + return true; + } else if (r == 0) { + if (*done != 0) { + nonrespond_errorf("read: unexpected EOF"); + return true; + } + *done += r; + } + } + return false; +} COROUTINE lib9p_srv_read_cr(void *_srv) { uint8_t buf[CONFIG_9P_MAX_MSG_SIZE]; @@ -51,80 +122,67 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { assert(srv); cr_begin(); + uint32_t initial_rerror_overhead = rerror_overhead_for_version(0, buf); + for (;;) { - struct lib9p_srvconn conn = { - .srv = srv, - .reader = cr_getcid(), - - .ctx = { - .version = LIB9P_VER_UNINITIALIZED, - .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, - }, - .refcount = 1, + struct lib9p_conn conn = { + .parent_srv = srv, + .fd = netio_accept(srv->sockfd), + .reader = cr_getcid(), }; - conn.fd = netio_accept(srv->sockfd); if (conn.fd < 0) { - fprintf(stderr, "error: accept: %s", strerror(-conn.fd)); + nonrespond_errorf("accept: %s", strerror(-conn.fd)); continue; } + struct lib9p_sess sess = { + .parent_conn = &conn, + .version = 0, + .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, + .rerror_overhead = initial_rerror_overhead, + .initialized = false, + .refcount = 1, + }; for (;;) { - /* Read the message size. */ - size_t goal = 4, done = 0; - while (done < goal) { - ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done); - if (r < 0) { - fprintf(stderr, "error: read: %s", strerror(-r)); - goto close; - } else if (r == 0) { - if (done != 0) - fprintf(stderr, "error: read: unexpected EOF"); - goto close; - } - done += r; - } - goal = decode_u32le(buf); + /* Read the message. */ + size_t done = 0; + if (read_at_least(conn.fd, buf, 4, &done)) + goto close; + size_t goal = decode_u32le(buf); if (goal < 7) { - /* We can't even respond with an Rerror becuase we wouldn't know what tag to use! */ - fprintf(stderr, "error: T-message is impossibly small"); + nonrespond_errorf("T-message is impossibly small"); goto close; } - if (goal > conn.ctx.max_msg_size) { - lib9p_errorf(&conn.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %zu)", - conn.ctx.version ? "negotiated" : "server", goal, conn.ctx.max_msg_size); - uint16_t tag = decode_u16le(&buf[5]); - marshal_error(&conn.ctx, tag, buf); - netio_write(conn.fd, buf, decode_u32le(buf)); + if (read_at_least(conn.fd, buf, 7, &done)) + goto close; + struct lib9p_req req = { + .parent_sess = &sess, + .tag = decode_u16le(&buf[5]), + .net_bytes = buf, + .ctx = { + .version = sess.version, + .max_msg_size = sess.max_msg_size, + }, + }; + if (goal > sess.max_msg_size) { + lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %"PRIu32")", + sess.initialized ? "negotiated" : "server", goal, sess.max_msg_size); + respond_error(&req); continue; } - /* Read the rest of the message. */ - while (done < goal) { - ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done); - if (r < 0) { - fprintf(stderr, "error: read: %s", strerror(-r)); - goto close; - } else if (r == 0) { - fprintf(stderr, "error: read: unexpected EOF"); - goto close; - } - done += r; - } + if (read_at_least(conn.fd, buf, goal, &done)) + goto close; - /* Handle the message... */ - if (conn.ctx.version == LIB9P_VER_UNINITIALIZED) { - /* ...synchronously if we haven't negotiated the protocol yet, ... */ - handle_message(&conn, buf); - } else { - /* ...asynchronously if we have. */ - cr_chan_send(&srv->reqch, buf); - cr_pause_and_yield(); - } + /* Handle the message... in another coroutine. */ + sess.refcount++; + cr_chan_send(&srv->reqch, &req); + cr_pause_and_yield(); /* wait for it to have copied req */ } close: - netio_close(conn.fd, true, (--conn.refcount) == 0); - if (conn.refcount) { + netio_close(conn.fd, true, (--sess.refcount) == 0); + if (sess.refcount) { cr_pause_and_yield(); - assert(conn.refcount == 0); + assert(sess.refcount == 0); netio_close(conn.fd, false, true); } } @@ -132,65 +190,118 @@ COROUTINE lib9p_srv_read_cr(void *_srv) { cr_end(); } +/* write coroutine ************************************************************/ + +static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp); + COROUTINE lib9p_srv_write_cr(void *_srv) { uint8_t net[CONFIG_9P_MAX_MSG_SIZE]; + uint8_t host_req[CONFIG_9P_MAX_HOSTMSG_SIZE]; + uint8_t host_resp[CONFIG_9P_MAX_HOSTMSG_SIZE]; - lib9p_srv *srv = _srv; + struct lib9p_srv *srv = _srv; assert(srv); cr_begin(); for (;;) { - struct lib9p_srvreq req; - cr_chan_recv(&srv->reqch, &req); - memcpy(net, req.msg, decode_u32le(req.msg)); - req.conn->refcount++; - cr_unpause(req.conn->reader); + /* Receive the request from the reader coroutine. */ + struct lib9p_req req; + struct lib9p_req *_req_p; + cr_chan_recv(&srv->reqch, &_req_p); + req = *_req_p; + memcpy(net, req.net_bytes, decode_u32le(req.net_bytes)); + req.net_bytes = net; + cr_unpause(req.parent_sess->parent_conn->reader); /* notify that we've copied req */ - handle_message(&req.conn, net); + /* Unmarshal it. */ + enum lib9p_msg_type typ = net[4]; + if (typ % 2 != 0) { + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "expected a T-message but got an R-message: message_type=%s", + lib9p_msg_type_str(typ)); + goto write; + } + ssize_t host_size = lib9p_unmarshal_size(&req.ctx, net); + if (host_size < 0) + goto write; + if ((size_t)host_size > sizeof(host_req)) { + lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)", + host_size, sizeof(host_req)); + goto write; + } + lib9p_unmarshal(&req.ctx, net, &typ, &req.tag, host_req); + + /* Handle it. */ + switch (typ) { + case LIB9P_TYP_Tversion: + handle_Tversion(&req, (struct lib9p_msg_Tversion *)host_req, (struct lib9p_msg_Rversion *)host_resp); + break; + case LIB9P_TYP_Tauth: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tattach: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tflush: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Twalk: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Topen: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tcreate: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tread: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Twrite: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tclunk: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tremove: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tstat: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Twstat: + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tsession: /* 9P2000.e */ + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tsread: /* 9P2000.e */ + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + case LIB9P_TYP_Tswrite: /* 9P2000.e */ + lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); + break; + default: + assert(false); + } - if ((--req.conn->refcount) == 0) - cr_unpause(req.conn->reader); + write: + if (req.ctx.err_num || req.ctx.err_msg[0]) + respond_error(&req); + else { + if (lib9p_marshal(&req.ctx, typ+1, req.tag, host_resp, net)) + goto write; + cr_mutex_lock(&req.parent_sess->parent_conn->writelock); + netio_write(req.parent_sess->parent_conn->fd, net, decode_u32le(net)); + cr_mutex_unlock(&req.parent_sess->parent_conn->writelock); + } + if ((--req.parent_sess->refcount) == 0) + cr_unpause(req.parent_sess->parent_conn->reader); } cr_end(); } -void handle_message(struct lib9p_srvconn *conn, uint8_t *net) { - uint8_t host[CONFIG_9P_MAX_MSG_SIZE]; - - struct lib9p_ctx ctx = { - .version = req.conn->version, - .max_msg_size = req.conn->max_msg_size, - }; - - size_t host_size = lib9p_unmarshal_size(&ctx, net); - if (host_size == (size_t)-1) - goto write; - if (host_size > sizeof(host)) { - lib9p_errorf(&ctx, LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)", host_size, sizeof(host)); - goto write; - } - - uint16_t tag; - uint8_t typ = lib9p_unmarshal(&ctx, net, &tag, host); - if (typ == (uint8_t)-1) - goto write; - if (typ % 2 != 0) { - lib9p_errorf(&ctx, LINUX_EOPNOTSUPP, "expected a T-message but got an R-message"); - goto write; - } - - TODO; - - write: - if (ctx.err_num || ctx.err_msg[0]) - marshal_error(&ctx, net); - else - TODO; - netio_write(req.conn->fd, net, decode_u32le(net)); -} - -void _version(struct lib9p_srv_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp) { +static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp) { enum lib9p_version version = LIB9P_VER_unknown; if (req->version.len >= 6 && @@ -201,30 +312,23 @@ void _version(struct lib9p_srv_req *ctx, struct lib9p_msg_Tversion *req, struct '0' <= req->version.utf8[4] && req->version.utf8[4] <= '9' && '0' <= req->version.utf8[5] && req->version.utf8[5] <= '9' && (req->version.utf8[6] == '\0' || req->version.utf8[6] == '.')) { - if (strcmp(&req->version.utf8[6], ".u") == 0) + if (strcmp((char *)&req->version.utf8[6], ".u") == 0) version = LIB9P_VER_9P2000_u; - //else if (strcmp(&req->version.utf8[6], ".e") == 0) - // version = LIB9P_VER_9P2000_e; + else if (strcmp((char *)&req->version.utf8[6], ".e") == 0) + version = LIB9P_VER_9P2000_e; else version = LIB9P_VER_9P2000; } - struct lib9p_ctx empty_ctx = { - .version = version, - .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, - }; - struct lib9p_msg_Rerror empty_error = { 0 }; - assert(!lib9p_marshal(&empty_ctx, LIB9P_TYP_Rerror, 0, &empty_error, ctx->net)); - uint32_t min_msg_size = decode_u32le(ctx->net); - assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size); - + uint32_t min_msg_size = rerror_overhead_for_version(version, ctx->net_bytes); if (req->max_msg_size < min_msg_size) { lib9p_errorf(&ctx->ctx, LINUX_EDOM, "requested max_msg_size is less than minimum for %s (%"PRIu32" < %"PRIu32")", version, req->max_msg_size, min_msg_size); return; } - resp->version = lib9p_version_str(version); + resp->version.utf8 = (uint8_t *)lib9p_version_str(version); + resp->version.len = strlen((char *)resp->version.utf8); resp->max_msg_size = (CONFIG_9P_MAX_MSG_SIZE < req->max_msg_size) ? CONFIG_9P_MAX_MSG_SIZE : req->max_msg_size; diff --git a/lib9p/types.c b/lib9p/types.c index 8a48b85..0c03bd5 100644 --- a/lib9p/types.c +++ b/lib9p/types.c @@ -24,7 +24,7 @@ const char *lib9p_version_str(enum lib9p_version ver) { return version_strs[ver]; } -static const char *msg_type_strs[0xFF] = { +static const char *msg_type_strs[0x100] = { [0x00] = "0x00", [0x01] = "0x01", [0x02] = "0x02", @@ -917,9 +917,6 @@ static bool marshal_Rattach(struct _marshal_ctx *ctx, struct lib9p_msg_Rattach * } static bool marshal_Rerror(struct _marshal_ctx *ctx, struct lib9p_msg_Rerror *val) { - /* Truncate the error-string if necessary to avoid returning ERANGE. */ - if (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size) - val->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead; return marshal_s(ctx, &val->ename) || ( (ctx->ctx->version==LIB9P_VER_9P2000_u) && marshal_4(ctx, &val->errno) ); } diff --git a/lib9p/types.gen b/lib9p/types.gen index 590e707..a597022 100755 --- a/lib9p/types.gen +++ b/lib9p/types.gen @@ -359,7 +359,7 @@ const char *{idprefix}version_str(enum {idprefix}version ver) {{ return version_strs[ver]; }} -static const char *msg_type_strs[0xFF] = {{ +static const char *msg_type_strs[0x100] = {{ """ id2name: dict[int, str] = {} for msg in structs: @@ -577,11 +577,6 @@ static inline bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { ret += "}\n" continue - if struct.name == "Rerror": # SPECIAL - ret += "\n\t/* Truncate the error-string if necessary to avoid returning ERANGE. */" - ret += "\n\tif (((uint32_t)val->ename.len) + ctx->ctx->Rerror_overhead > ctx->ctx->max_msg_size)" - ret += "\n\t\tval->ename.len = ctx->ctx->max_msg_size - ctx->ctx->Rerror_overhead;" - prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " |