summaryrefslogtreecommitdiff
path: root/lib9p/srv.c
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/srv.c')
-rw-r--r--lib9p/srv.c368
1 files changed, 236 insertions, 132 deletions
diff --git a/lib9p/srv.c b/lib9p/srv.c
index 74c2014..997d137 100644
--- a/lib9p/srv.c
+++ b/lib9p/srv.c
@@ -1,48 +1,119 @@
#include <assert.h>
-#include <stdio.h> /* for fprintf(), stderr */
-#include <string.h> /* for strerror() */
+#include <inttypes.h> /* for PRI* */
+#include <stdio.h> /* for fprintf(), stderr */
+#include <string.h> /* for strerror() */
#include <libcr/coroutine.h>
#include <libcr_ipc/chan.h>
+#include <libcr_ipc/mutex.h>
#include <libnetio/netio.h>
#include <lib9p/9p.h>
#include <lib9p/srv.h>
#include "internal.h"
+/* structs ********************************************************************/
+
/* The hierarchy of concepts is:
*
* server -> connection -> session -> request
*
*/
-struct lib9p_srv_sess {
+/* struct lib9p_srv {} is defined in <lib9p/srv.h> */
+
+struct lib9p_conn {
/* immutable */
- struct lib9p_srv *srv;
- int connfd;
- cid_t reader; /* the lib9p_srv_read_cr() coroutine for this session */
+ struct lib9p_srv *parent_srv;
+ int fd;
+ cid_t reader; /* the lib9p_srv_read_cr() coroutine */
/* mutable */
- struct lib9p_ctx ctx;
+ cr_mutex_t writelock;
+};
+
+struct lib9p_sess {
+ /* immutable */
+ struct lib9p_conn *parent_conn;
+ enum lib9p_version version;
+ uint32_t max_msg_size;
+ uint32_t rerror_overhead;
+ /* mutable */
+ bool initialized;
unsigned int refcount;
};
-struct lib9p_srv_req {
- struct lib9p_srv_sess *sess;
- uint8_t *msg;
+struct lib9p_req {
+ /* immutable */
+ struct lib9p_sess *parent_sess;
+ uint16_t tag;
+ /* mutable */
+ uint8_t *net_bytes; /* CONFIG_9P_MAX_MSG_SIZE-sized */
+ struct lib9p_ctx ctx;
};
-static void marshal_error(struct lib9p_ctx *ctx, uint16_t tag, uint8_t *net) {
+/* base utilities *************************************************************/
+
+#define nonrespond_errorf(...) fprintf(stderr, "error: " __VA_ARGS__)
+
+static uint32_t rerror_overhead_for_version(enum lib9p_version version, uint8_t *scratch) {
+ struct lib9p_ctx empty_ctx = {
+ .version = version,
+ .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
+ };
+ struct lib9p_msg_Rerror empty_error = { 0 };
+
+ assert(!lib9p_marshal(&empty_ctx, LIB9P_TYP_Rerror, 0, &empty_error, scratch));
+ uint32_t min_msg_size = decode_u32le(scratch);
+ assert(min_msg_size < (UINT32_MAX - UINT16_MAX));
+ assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size);
+
+ return min_msg_size;
+}
+
+static void respond_error(struct lib9p_req *req) {
+ assert(req->ctx.err_num);
+ assert(req->ctx.err_msg[0]);
+
+ ssize_t r;
struct lib9p_msg_Rerror host = {
.ename = {
- .len = strnlen(ctx->err_msg, CONFIG_9P_MAX_ERR_SIZE),
- .utf8 = (uint8_t*)ctx->err_msg,
+ .len = strnlen(req->ctx.err_msg, CONFIG_9P_MAX_ERR_SIZE),
+ .utf8 = (uint8_t*)req->ctx.err_msg,
},
- .errno = ctx->err_num,
+ .errno = req->ctx.err_num,
};
- lib9p_marshal(ctx, LIB9P_TYP_Rerror, tag, &host, net);
+
+ /* Truncate the error-string if necessary to avoid needing to return ERANGE. */
+ if (((uint32_t)host.ename.len) + req->parent_sess->rerror_overhead > req->parent_sess->max_msg_size)
+ host.ename.len = req->parent_sess->max_msg_size - req->parent_sess->rerror_overhead;
+
+ lib9p_marshal(&req->ctx, LIB9P_TYP_Rerror, req->tag, &host, req->net_bytes);
+
+ cr_mutex_lock(&req->parent_sess->parent_conn->writelock);
+ r = netio_write(req->parent_sess->parent_conn->fd, req->net_bytes, decode_u32le(req->net_bytes));
+ cr_mutex_unlock(&req->parent_sess->parent_conn->writelock);
+ if (r < 0)
+ nonrespond_errorf("write: %s", strerror(-r));
}
-void handle_message(struct lib9p_srvconn *conn, uint8_t *net);
+/* read coroutine *************************************************************/
+
+static bool read_at_least(int fd, uint8_t *buf, size_t goal, size_t *done) {
+ while (*done < goal) {
+ ssize_t r = netio_read(fd, &buf[*done], CONFIG_9P_MAX_MSG_SIZE - *done);
+ if (r < 0) {
+ nonrespond_errorf("read: %s", strerror(-r));
+ return true;
+ } else if (r == 0) {
+ if (*done != 0) {
+ nonrespond_errorf("read: unexpected EOF");
+ return true;
+ }
+ *done += r;
+ }
+ }
+ return false;
+}
COROUTINE lib9p_srv_read_cr(void *_srv) {
uint8_t buf[CONFIG_9P_MAX_MSG_SIZE];
@@ -51,80 +122,67 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
assert(srv);
cr_begin();
+ uint32_t initial_rerror_overhead = rerror_overhead_for_version(0, buf);
+
for (;;) {
- struct lib9p_srvconn conn = {
- .srv = srv,
- .reader = cr_getcid(),
-
- .ctx = {
- .version = LIB9P_VER_UNINITIALIZED,
- .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
- },
- .refcount = 1,
+ struct lib9p_conn conn = {
+ .parent_srv = srv,
+ .fd = netio_accept(srv->sockfd),
+ .reader = cr_getcid(),
};
- conn.fd = netio_accept(srv->sockfd);
if (conn.fd < 0) {
- fprintf(stderr, "error: accept: %s", strerror(-conn.fd));
+ nonrespond_errorf("accept: %s", strerror(-conn.fd));
continue;
}
+ struct lib9p_sess sess = {
+ .parent_conn = &conn,
+ .version = 0,
+ .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
+ .rerror_overhead = initial_rerror_overhead,
+ .initialized = false,
+ .refcount = 1,
+ };
for (;;) {
- /* Read the message size. */
- size_t goal = 4, done = 0;
- while (done < goal) {
- ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done);
- if (r < 0) {
- fprintf(stderr, "error: read: %s", strerror(-r));
- goto close;
- } else if (r == 0) {
- if (done != 0)
- fprintf(stderr, "error: read: unexpected EOF");
- goto close;
- }
- done += r;
- }
- goal = decode_u32le(buf);
+ /* Read the message. */
+ size_t done = 0;
+ if (read_at_least(conn.fd, buf, 4, &done))
+ goto close;
+ size_t goal = decode_u32le(buf);
if (goal < 7) {
- /* We can't even respond with an Rerror becuase we wouldn't know what tag to use! */
- fprintf(stderr, "error: T-message is impossibly small");
+ nonrespond_errorf("T-message is impossibly small");
goto close;
}
- if (goal > conn.ctx.max_msg_size) {
- lib9p_errorf(&conn.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %zu)",
- conn.ctx.version ? "negotiated" : "server", goal, conn.ctx.max_msg_size);
- uint16_t tag = decode_u16le(&buf[5]);
- marshal_error(&conn.ctx, tag, buf);
- netio_write(conn.fd, buf, decode_u32le(buf));
+ if (read_at_least(conn.fd, buf, 7, &done))
+ goto close;
+ struct lib9p_req req = {
+ .parent_sess = &sess,
+ .tag = decode_u16le(&buf[5]),
+ .net_bytes = buf,
+ .ctx = {
+ .version = sess.version,
+ .max_msg_size = sess.max_msg_size,
+ },
+ };
+ if (goal > sess.max_msg_size) {
+ lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "T-message larger than %s limit (%zu > %"PRIu32")",
+ sess.initialized ? "negotiated" : "server", goal, sess.max_msg_size);
+ respond_error(&req);
continue;
}
- /* Read the rest of the message. */
- while (done < goal) {
- ssize_t r = netio_read(conn.fd, &buf[done], sizeof(buf)-done);
- if (r < 0) {
- fprintf(stderr, "error: read: %s", strerror(-r));
- goto close;
- } else if (r == 0) {
- fprintf(stderr, "error: read: unexpected EOF");
- goto close;
- }
- done += r;
- }
+ if (read_at_least(conn.fd, buf, goal, &done))
+ goto close;
- /* Handle the message... */
- if (conn.ctx.version == LIB9P_VER_UNINITIALIZED) {
- /* ...synchronously if we haven't negotiated the protocol yet, ... */
- handle_message(&conn, buf);
- } else {
- /* ...asynchronously if we have. */
- cr_chan_send(&srv->reqch, buf);
- cr_pause_and_yield();
- }
+ /* Handle the message... in another coroutine. */
+ sess.refcount++;
+ cr_chan_send(&srv->reqch, &req);
+ cr_pause_and_yield(); /* wait for it to have copied req */
}
close:
- netio_close(conn.fd, true, (--conn.refcount) == 0);
- if (conn.refcount) {
+ netio_close(conn.fd, true, (--sess.refcount) == 0);
+ if (sess.refcount) {
cr_pause_and_yield();
- assert(conn.refcount == 0);
+ assert(sess.refcount == 0);
netio_close(conn.fd, false, true);
}
}
@@ -132,65 +190,118 @@ COROUTINE lib9p_srv_read_cr(void *_srv) {
cr_end();
}
+/* write coroutine ************************************************************/
+
+static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp);
+
COROUTINE lib9p_srv_write_cr(void *_srv) {
uint8_t net[CONFIG_9P_MAX_MSG_SIZE];
+ uint8_t host_req[CONFIG_9P_MAX_HOSTMSG_SIZE];
+ uint8_t host_resp[CONFIG_9P_MAX_HOSTMSG_SIZE];
- lib9p_srv *srv = _srv;
+ struct lib9p_srv *srv = _srv;
assert(srv);
cr_begin();
for (;;) {
- struct lib9p_srvreq req;
- cr_chan_recv(&srv->reqch, &req);
- memcpy(net, req.msg, decode_u32le(req.msg));
- req.conn->refcount++;
- cr_unpause(req.conn->reader);
+ /* Receive the request from the reader coroutine. */
+ struct lib9p_req req;
+ struct lib9p_req *_req_p;
+ cr_chan_recv(&srv->reqch, &_req_p);
+ req = *_req_p;
+ memcpy(net, req.net_bytes, decode_u32le(req.net_bytes));
+ req.net_bytes = net;
+ cr_unpause(req.parent_sess->parent_conn->reader); /* notify that we've copied req */
- handle_message(&req.conn, net);
+ /* Unmarshal it. */
+ enum lib9p_msg_type typ = net[4];
+ if (typ % 2 != 0) {
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "expected a T-message but got an R-message: message_type=%s",
+ lib9p_msg_type_str(typ));
+ goto write;
+ }
+ ssize_t host_size = lib9p_unmarshal_size(&req.ctx, net);
+ if (host_size < 0)
+ goto write;
+ if ((size_t)host_size > sizeof(host_req)) {
+ lib9p_errorf(&req.ctx, LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)",
+ host_size, sizeof(host_req));
+ goto write;
+ }
+ lib9p_unmarshal(&req.ctx, net, &typ, &req.tag, host_req);
+
+ /* Handle it. */
+ switch (typ) {
+ case LIB9P_TYP_Tversion:
+ handle_Tversion(&req, (struct lib9p_msg_Tversion *)host_req, (struct lib9p_msg_Rversion *)host_resp);
+ break;
+ case LIB9P_TYP_Tauth:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tattach:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tflush:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Twalk:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Topen:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tcreate:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tread:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Twrite:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tclunk:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tremove:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tstat:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Twstat:
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tsession: /* 9P2000.e */
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tsread: /* 9P2000.e */
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ case LIB9P_TYP_Tswrite: /* 9P2000.e */
+ lib9p_errorf(&req.ctx, LINUX_EOPNOTSUPP, "%s not yet implemented", lib9p_msg_type_str(typ));
+ break;
+ default:
+ assert(false);
+ }
- if ((--req.conn->refcount) == 0)
- cr_unpause(req.conn->reader);
+ write:
+ if (req.ctx.err_num || req.ctx.err_msg[0])
+ respond_error(&req);
+ else {
+ if (lib9p_marshal(&req.ctx, typ+1, req.tag, host_resp, net))
+ goto write;
+ cr_mutex_lock(&req.parent_sess->parent_conn->writelock);
+ netio_write(req.parent_sess->parent_conn->fd, net, decode_u32le(net));
+ cr_mutex_unlock(&req.parent_sess->parent_conn->writelock);
+ }
+ if ((--req.parent_sess->refcount) == 0)
+ cr_unpause(req.parent_sess->parent_conn->reader);
}
cr_end();
}
-void handle_message(struct lib9p_srvconn *conn, uint8_t *net) {
- uint8_t host[CONFIG_9P_MAX_MSG_SIZE];
-
- struct lib9p_ctx ctx = {
- .version = req.conn->version,
- .max_msg_size = req.conn->max_msg_size,
- };
-
- size_t host_size = lib9p_unmarshal_size(&ctx, net);
- if (host_size == (size_t)-1)
- goto write;
- if (host_size > sizeof(host)) {
- lib9p_errorf(&ctx, LINUX_EMSGSIZE, "unmarshalled payload larger than server limit (%zu > %zu)", host_size, sizeof(host));
- goto write;
- }
-
- uint16_t tag;
- uint8_t typ = lib9p_unmarshal(&ctx, net, &tag, host);
- if (typ == (uint8_t)-1)
- goto write;
- if (typ % 2 != 0) {
- lib9p_errorf(&ctx, LINUX_EOPNOTSUPP, "expected a T-message but got an R-message");
- goto write;
- }
-
- TODO;
-
- write:
- if (ctx.err_num || ctx.err_msg[0])
- marshal_error(&ctx, net);
- else
- TODO;
- netio_write(req.conn->fd, net, decode_u32le(net));
-}
-
-void _version(struct lib9p_srv_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp) {
+static void handle_Tversion(struct lib9p_req *ctx, struct lib9p_msg_Tversion *req, struct lib9p_msg_Rversion *resp) {
enum lib9p_version version = LIB9P_VER_unknown;
if (req->version.len >= 6 &&
@@ -201,30 +312,23 @@ void _version(struct lib9p_srv_req *ctx, struct lib9p_msg_Tversion *req, struct
'0' <= req->version.utf8[4] && req->version.utf8[4] <= '9' &&
'0' <= req->version.utf8[5] && req->version.utf8[5] <= '9' &&
(req->version.utf8[6] == '\0' || req->version.utf8[6] == '.')) {
- if (strcmp(&req->version.utf8[6], ".u") == 0)
+ if (strcmp((char *)&req->version.utf8[6], ".u") == 0)
version = LIB9P_VER_9P2000_u;
- //else if (strcmp(&req->version.utf8[6], ".e") == 0)
- // version = LIB9P_VER_9P2000_e;
+ else if (strcmp((char *)&req->version.utf8[6], ".e") == 0)
+ version = LIB9P_VER_9P2000_e;
else
version = LIB9P_VER_9P2000;
}
- struct lib9p_ctx empty_ctx = {
- .version = version,
- .max_msg_size = CONFIG_9P_MAX_MSG_SIZE,
- };
- struct lib9p_msg_Rerror empty_error = { 0 };
- assert(!lib9p_marshal(&empty_ctx, LIB9P_TYP_Rerror, 0, &empty_error, ctx->net));
- uint32_t min_msg_size = decode_u32le(ctx->net);
- assert(CONFIG_9P_MAX_MSG_SIZE >= min_msg_size);
-
+ uint32_t min_msg_size = rerror_overhead_for_version(version, ctx->net_bytes);
if (req->max_msg_size < min_msg_size) {
lib9p_errorf(&ctx->ctx, LINUX_EDOM, "requested max_msg_size is less than minimum for %s (%"PRIu32" < %"PRIu32")",
version, req->max_msg_size, min_msg_size);
return;
}
- resp->version = lib9p_version_str(version);
+ resp->version.utf8 = (uint8_t *)lib9p_version_str(version);
+ resp->version.len = strlen((char *)resp->version.utf8);
resp->max_msg_size = (CONFIG_9P_MAX_MSG_SIZE < req->max_msg_size)
? CONFIG_9P_MAX_MSG_SIZE
: req->max_msg_size;