#include #include /* for PRI* */ #include /* for fprintf(), stderr */ #include /* for strerror() */ #include #include #include #include #include #include #include "internal.h" /* structs ********************************************************************/ /* The hierarchy of concepts is: * * server -> connection -> session -> request * */ /* struct lib9p_srv {} is defined in */ struct lib9p_conn { /* immutable */ struct lib9p_srv *parent_srv; int fd; cid_t reader; /* the lib9p_srv_read_cr() coroutine */ /* mutable */ cr_mutex_t writelock; }; struct lib9p_sess { /* immutable */ struct lib9p_conn *parent_conn; enum lib9p_version version; uint32_t max_msg_size; uint32_t rerror_overhead; /* mutable */ bool initialized; unsigned int refcount; }; struct lib9p_req { /* immutable */ struct lib9p_sess *parent_sess; uint16_t tag; /* mutable */ uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */ struct lib9p_ctx ctx; }; /* base utilities *************************************************************/ #define nonrespond_errorf(fmt, ...) fprintf(stderr, "error: " fmt "\n" __VA_OPT__(,) __VA_ARGS__) static uint32_t rerror_overhead_for_version(enum lib9p_version version, uint8_t *scratch) { struct lib9p_ctx empty_ctx = { .version = version, .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, }; struct lib9p_msg_Rerror empty_error = { 0 }; assert(!lib9p_marshal(&empty_ctx, LIB9P_TYP_Rerror, 0, &empty_error, scratch)); uint32_t min_msg_size = decode_u32le(scratch); assert(min_msg_size < (UINT32_MAX - UINT16_MAX)); assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size); return min_msg_size; } static void respond_error(struct lib9p_req *req) { assert(req->ctx.err_num); assert(req->ctx.err_msg[0]); ssize_t r; struct lib9p_msg_Rerror host = { .ename = { .len = strnlen(req->ctx.err_msg, CONFIG_9P_MAX_ERR_SIZE), .utf8 = (uint8_t*)req->ctx.err_msg, }, .errno = req->ctx.err_num, }; /* Truncate the error-string if necessary to avoid needing to return ERANGE. */ if (((uint32_t)host.ename.len) + req->parent_sess->rerror_overhead > req->parent_sess->max_msg_size) host.ename.len = req->parent_sess->max_msg_size - req->parent_sess->rerror_overhead; lib9p_marshal(&req->ctx, LIB9P_TYP_Rerror, req->tag, &host, req->net_bytes); cr_mutex_lock(&req->parent_sess->parent_conn->writelock); r = netio_write(req->parent_sess->parent_conn->fd, req->net_bytes, decode_u32le(req->net_bytes)); cr_mutex_unlock(&req->parent_sess->parent_conn->writelock); if (r < 0) nonrespond_errorf("write: %s", strerror(-r)); } /* read coroutine *************************************************************/ static bool read_at_least(int fd, uint8_t *buf, size_t goal, size_t *done) { assert(buf); assert(goal); assert(done); while (*done < goal) { ssize_t r = netio_read(fd, &buf[*done], CONFIG_9P_MAX_MSG_SIZE - *done); if (r < 0) { nonrespond_errorf("read: %s", strerror(-r)); return true; } else if (r == 0) { if (*done != 0) nonrespond_errorf("read: unexpected EOF"); return true; } *done += r; } return false; } COROUTINE lib9p_srv_read_cr(void *_srv) { uint8_t buf[CONFIG_9P_MAX_MSG_SIZE]; struct lib9p_srv *srv = _srv; assert(srv); cr_begin(); uint32_t initial_rerror_overhead = rerror_overhead_for_version(0, buf); for (;;) { struct lib9p_conn conn = { .parent_srv = srv, .fd = netio_accept(srv->sockfd), .reader = cr_getcid(), }; if (conn.fd < 0) { nonrespond_errorf("accept: %s", strerror(-conn.fd)); continue; } struct lib9p_sess sess = { .parent_conn = &conn, .version = 0, .max_msg_size = CONFIG_9P_MAX_MSG_SIZE, .rerror_overhead = initial_rerror_overhead, .initialized = false, .refcount = 1, }; for (;;) { /* Read the message. */ size_t done = 0; if (read_at_least(conn.fd, buf, 4, &done)) goto close; size_t goal = decode_u32le(buf); if (goal < 7) { nonrespond_errorf("T-message is impossibly small"); goto close; } if (read_at_least(conn.fd, buf, 7, &done)) goto close; struct lib9p_req req = { .parent_sess = &sess, .tag = decode_u16le(&buf[5]), .net_bytes = buf, .ctx = { .version = sess.version, .max_msg_size = sess.max_msg_size, }, }; if (goal > sess.max_msg_size) { lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %"PRIu32")", sess.initialized ? "negotiated" : "server", goal, sess.max_msg_size); respond_error(&req); continue; } if (read_at_least(conn.fd, buf, goal, &done)) goto close; /* Handle the message... in another coroutine. */ sess.refcount++; cr_chan_send(&srv->reqch, &req); cr_pause_and_yield(); /* wait for it to have copied req */ } close: netio_close(conn.fd, true, (--sess.refcount) == 0); if (sess.refcount) { cr_pause_and_yield(); assert(sess.refcount == 0); netio_close(conn.fd, false, true); } } cr_end(); } /* write coroutine ************************************************************/ static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp); COROUTINE lib9p_srv_write_cr(void *_srv) { uint8_t net[CONFIG_9P_MAX_MSG_SIZE]; uint8_t host_req[CONFIG_9P_MAX_HOSTMSG_SIZE]; uint8_t host_resp[CONFIG_9P_MAX_HOSTMSG_SIZE]; struct lib9p_srv *srv = _srv; assert(srv); cr_begin(); for (;;) { /* Receive the request from the reader coroutine. */ struct lib9p_req req; struct lib9p_req *_req_p; cr_chan_recv(&srv->reqch, &_req_p); req = *_req_p; memcpy(net, req.net_bytes, decode_u32le(req.net_bytes)); req.net_bytes = net; cr_unpause(req.parent_sess->parent_conn->reader); /* notify that we've copied req */ /* Unmarshal it. */ enum lib9p_msg_type typ = net[4]; if (typ % 2 != 0) { lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "expected a T-message but got an R-message: message_type=%s", lib9p_msg_type_str(typ)); goto write; } ssize_t host_size = lib9p_unmarshal_size(&req.ctx, net); if (host_size < 0) goto write; if ((size_t)host_size > sizeof(host_req)) { lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)", host_size, sizeof(host_req)); goto write; } lib9p_unmarshal(&req.ctx, net, &typ, &req.tag, host_req); /* Handle it. */ switch (typ) { case LIB9P_TYP_Tversion: handle_Tversion(&req, (struct lib9p_msg_Tversion *)host_req, (struct lib9p_msg_Rversion *)host_resp); break; case LIB9P_TYP_Tauth: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tattach: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tflush: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Twalk: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Topen: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tcreate: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tread: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Twrite: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tclunk: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tremove: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tstat: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Twstat: lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tsession: /* 9P2000.e */ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tsread: /* 9P2000.e */ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; case LIB9P_TYP_Tswrite: /* 9P2000.e */ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ)); break; default: assert(false); } write: if (req.ctx.err_num || req.ctx.err_msg[0]) respond_error(&req); else { if (lib9p_marshal(&req.ctx, typ+1, req.tag, host_resp, net)) goto write; cr_mutex_lock(&req.parent_sess->parent_conn->writelock); netio_write(req.parent_sess->parent_conn->fd, net, decode_u32le(net)); cr_mutex_unlock(&req.parent_sess->parent_conn->writelock); } if ((--req.parent_sess->refcount) == 0) cr_unpause(req.parent_sess->parent_conn->reader); } cr_end(); } static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp) { enum lib9p_version version = LIB9P_VER_unknown; if (req->version.len >= 6 && req->version.utf8[0] == '9' && req->version.utf8[1] == 'P' && '0' <= req->version.utf8[2] && req->version.utf8[2] <= '9' && '0' <= req->version.utf8[3] && req->version.utf8[3] <= '9' && '0' <= req->version.utf8[4] && req->version.utf8[4] <= '9' && '0' <= req->version.utf8[5] && req->version.utf8[5] <= '9' && (req->version.utf8[6] == '\0' || req->version.utf8[6] == '.')) { if (strcmp((char *)&req->version.utf8[6], ".u") == 0) version = LIB9P_VER_9P2000_u; else if (strcmp((char *)&req->version.utf8[6], ".e") == 0) version = LIB9P_VER_9P2000_e; else version = LIB9P_VER_9P2000; } uint32_t min_msg_size = rerror_overhead_for_version(version, ctx->net_bytes); if (req->max_msg_size < min_msg_size) { lib9p_errorf(&ctx->ctx, LINUX_EDOM, "requested max_msg_size is less than minimum for %s (%"PRIu32" < %"PRIu32")", version, req->max_msg_size, min_msg_size); return; } resp->version.utf8 = (uint8_t *)lib9p_version_str(version); resp->version.len = strlen((char *)resp->version.utf8); resp->max_msg_size = (CONFIG_9P_MAX_MSG_SIZE < req->max_msg_size) ? CONFIG_9P_MAX_MSG_SIZE : req->max_msg_size; }