summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--coroutine.c53
-rw-r--r--coroutine.h2
-rw-r--r--coroutine_sema.c95
-rw-r--r--coroutine_sema.h35
-rw-r--r--netio_posix.c44
5 files changed, 189 insertions, 40 deletions
diff --git a/coroutine.c b/coroutine.c
index f138fcf..12d4a25 100644
--- a/coroutine.c
+++ b/coroutine.c
@@ -154,10 +154,11 @@ enum coroutine_state {
*/
struct coroutine {
- enum coroutine_state state;
- jmp_buf env;
- size_t stack_size;
- void *stack;
+ volatile enum coroutine_state state;
+ volatile bool sig_unpause;
+ jmp_buf env;
+ size_t stack_size;
+ void *stack;
};
static struct coroutine coroutine_table[CONFIG_COROUTINE_NUM] = {0};
@@ -213,8 +214,10 @@ static void call_with_stack(void *stack, cr_fn_t fn, void *args) {
static const uint8_t stack_pattern[] = {0x1e, 0x15, 0x16, 0x0a, 0xcc, 0x52, 0x7e, 0xb7};
#endif
+static void inline assert_cid(cid_t cid) {
+ assert(cid > 0);
+ assert(cid <= CONFIG_COROUTINE_NUM);
#if CONFIG_COROUTINE_PROTECT_STACK
-void assert_stack_protection(cid_t cid) {
assert(coroutine_table[cid-1].stack_size);
assert(coroutine_table[cid-1].stack);
for (size_t i = 0; i < sizeof(stack_pattern); i++) {
@@ -222,16 +225,12 @@ void assert_stack_protection(cid_t cid) {
assert(((uint8_t*)coroutine_table[cid-1].stack)[i] == stack_pattern[i]);
assert(((uint8_t*)coroutine_table[cid-1].stack)[j] == stack_pattern[j%sizeof(stack_pattern)]);
}
-}
-#else
-# define assert_stack_protection(cid) ((void)0)
#endif
+}
#define assert_cid_state(cid, opstate) do { \
- assert((cid) > 0); \
- assert((cid) <= CONFIG_COROUTINE_NUM); \
+ assert_cid(cid); \
assert(coroutine_table[(cid)-1].state opstate); \
- assert_stack_protection(cid); \
} while (0)
cid_t coroutine_add_with_stack_size(size_t stack_size, cr_fn_t fn, void *args) {
@@ -344,7 +343,6 @@ void cr_begin(void) {
}
static inline void _cr_transition(enum coroutine_state state) {
- assert_cid_state(coroutine_running, == CR_RUNNING);
debugf("cid=%zu: transition %i->%i\n", coroutine_running, coroutine_table[coroutine_running-1].state, state);
coroutine_table[coroutine_running-1].state = state;
@@ -352,8 +350,19 @@ static inline void _cr_transition(enum coroutine_state state) {
longjmp(coroutine_main_env, 1); /* jump to point=b */
}
-void cr_yield(void) { _cr_transition(CR_RUNNABLE); }
-void cr_pause_and_yield(void) { _cr_transition(CR_PAUSED); }
+void cr_yield(void) {
+ assert_cid_state(coroutine_running, == CR_RUNNING);
+ _cr_transition(CR_RUNNABLE);
+}
+
+void cr_pause_and_yield(void) {
+ assert_cid_state(coroutine_running, == CR_RUNNING);
+ if (coroutine_table[cid-1].sig_unpause)
+ _cr_transition(CR_RUNNABLE);
+ else
+ _cr_transition(CR_PAUSED);
+ coroutine_table[cid-1].sig_unpause = false;
+}
void cr_exit(void) {
assert_cid_state(coroutine_running, == CR_RUNNING);
@@ -367,7 +376,21 @@ void cr_unpause(cid_t cid) {
assert_cid_state(cid, == CR_PAUSED);
debugf("cr_unpause(%zu)\n", cid);
- coroutine_table[cid-1].state = CR_RUNNABLE;
+ coroutine_table[cid-1].sig_unpause = false;
+ coroutine_table[cid-1].state = CR_RUNNABLE;
+}
+
+void cr_unpause_from_sighandler(cid_t cid) {
+ assert_cid(cid);
+
+ switch (coroutine_table[cid-1].state) {
+ case CR_RUNNING:
+ coroutine_table[cid-1].sig_unpause = true;
+ case CR_PAUSED:
+ coroutine_table[cid-1].state = CR_RUNNABLE;
+ default:
+ assert(false);
+ }
}
cid_t cr_getcid(void) {
diff --git a/coroutine.h b/coroutine.h
index 4bdc4f8..4d83181 100644
--- a/coroutine.h
+++ b/coroutine.h
@@ -114,6 +114,8 @@ void cr_yield(void);
void cr_pause_and_yield(void);
/** cr_unpause() marks a coroutine as runnable that has previously marked itself as non-runnable with cr_pause_and_yield(). */
void cr_unpause(cid_t);
+/** cr_unpause_from_sighandler() is like cr_unpause(), but safe to call from a signal handler that might race with the coroutine actually pausing itself. */
+void cr_unpause_from_sighandler(cid_t);
/** cr_end() is a counterpart to cr_begin(), but is really just cr_exit(). */
#define cr_end cr_exit
diff --git a/coroutine_sema.c b/coroutine_sema.c
new file mode 100644
index 0000000..a8b5ec4
--- /dev/null
+++ b/coroutine_sema.c
@@ -0,0 +1,95 @@
+/* coroutine_sema.h - Simple semaphores for coroutine.{h,c}
+ *
+ * Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
+ * SPDX-Licence-Identifier: AGPL-3.0-or-later
+ */
+
+#include <assert.h>
+
+#include "coroutine_sema.h"
+
+struct cid_list {
+ cid_t val;
+ struct cid_list *next;
+};
+
+/* head->next->next->tail */
+
+struct _cr_sema {
+ int cnt;
+
+ struct cid_list *head, **tail;
+ /* locked indicates that a call from within a coroutine is is
+ * messing with ->{head,tail}, so a signal handler can't read
+ * it. */
+ bool locked;
+};
+
+/** Drain the sema->{head,tail} list. Returns true if cr_getcid() was drained. */
+static inline bool drain(volatile cr_sema_t *sema) {
+ assert(!sema->locked);
+ cid_t self = cr_getcid();
+
+ enum drain_result {
+ DRAINING,
+ DRAINED_SELF, /* stopped because drained `self` */
+ DRAINED_ALL, /* stopped because sema->head == NULL */
+ DRAINED_SOME, /* stopped because sema->cnt == 0 */
+ } state = DRAINING;
+ do {
+ sema->locked = true;
+ while (state == DRAINING) {
+ if (!sema->head) {
+ state = DRAINED_ALL;
+ } else if (!sema->cnt) {
+ state = DRAINED_SOME;
+ } else {
+ sema->cnt--;
+ cid_t cid = sema->head->val;
+ if (cid == self)
+ state = DRAINED_SELF;
+ else
+ cr_unpause(sema->head->val);
+ sema->head = sema->head->next;
+ if (!sema->head)
+ sema->tail = &sema->head;
+ }
+ }
+ sema->locked = false;
+ /* If there are still coroutines in sema->head, check
+ * that sema->cnt wasn't incremented between `if
+ * (!sema->cnt)` and `sema->locked = false`. */
+ } while (state == DRAINED_SOME && cnt);
+ /* If state == DRAINED_SELF, then we better have been the last
+ * item in the list! */
+ assert(state != DRAINED_SELF || !sema->head);
+ return state == DRAINED_SELF;
+}
+
+void cr_sema_signal(volatile cr_sema_t *sema) {
+ sema->cnt++;
+ if (!sema->locked)
+ drain();
+}
+
+void cr_sema_wait(volatile cr_sema_t *sema) {
+ struct cid_list self = {
+ .val = cr_getcid(),
+ .next = NULL,
+ };
+
+ sema->locked = true;
+ if (!sema->tail)
+ sema->head = &self;
+ else
+ *(sema->tail) = &self;
+ sema->tail = &(self.next);
+ sema->locked = false;
+
+ if (drain())
+ /* DRAINED_SELF: (1) No need to pause+yield, (2) we
+ * better have been the last item in the list! */
+ assert(!self.next);
+ else
+ cr_pause_and_yield();
+}
diff --git a/coroutine_sema.h b/coroutine_sema.h
new file mode 100644
index 0000000..57d5855
--- /dev/null
+++ b/coroutine_sema.h
@@ -0,0 +1,35 @@
+/* coroutine_sema.h - Simple semaphores for coroutine.{h,c}
+ *
+ * Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
+ * SPDX-Licence-Identifier: AGPL-3.0-or-later
+ */
+
+#ifndef _COROUTINE_SEMA_H_
+#define _COROUTINE_SEMA_H_
+
+/**
+ * A cr_sema_t is a fair unbounded[1] counting semaphore.
+ *
+ * [1]: Well, INT_MAX
+ */
+typedef struct sema_t cr_sema_t;
+
+/**
+ * Increment the semaphore,
+ *
+ * @blocks never
+ * @yields never
+ * @run_in anywhere (coroutine, sighandler)
+ */
+void cr_sema_signal(cr_sema_t *sema);
+
+/**
+ * Wait until the semaphore is >0, then decrement it.
+ *
+ * @blocks maybe
+ * @yields maybe
+ * @may_run_in coroutine
+ */
+void cr_sema_wait(cr_sema_t *sema);
+
+#endif /* _COROUTINE_SEMA_H_ */
diff --git a/netio_posix.c b/netio_posix.c
index 3cc00bb..46851f7 100644
--- a/netio_posix.c
+++ b/netio_posix.c
@@ -18,6 +18,7 @@
#include "netio.h"
#include "coroutine.h"
+#include "coroutine_sema.h"
/* I found the following post to be very helpful when writing this:
* http://davmac.org/davpage/linux/async-io.html */
@@ -29,16 +30,16 @@ static int sig_accept = 0;
#endif
struct netio_socket {
- int fd;
+ int fd;
#if CONFIG_NETIO_ISLINUX
- cid_t accept_waiters[CONFIG_NETIO_NUM_CONNS];
+ cr_sema_t accept_waiters;
#endif
};
static struct netio_socket socket_table[CONFIG_NETIO_NUM_PORTS] = {0};
static void handle_sig_io(int sig __attribute__ ((unused)), siginfo_t *info, void *ucontext __attribute__ ((unused))) {
- cr_unpause((cid_t)info->si_value.sival_int);
+ cr_unpause_from_sighandler((cid_t)info->si_value.sival_int);
}
#if CONFIG_NETIO_ISLINUX
@@ -49,12 +50,7 @@ static void handle_sig_accept(int sig __attribute__ ((unused)), siginfo_t *info,
sock = &socket_table[i];
if (!sock)
return;
- for (int i = 0; i < CONFIG_NETIO_NUM_CONNS; i++)
- if (sock->accept_waiters[i] > 0) {
- cr_unpause(sock->accept_waiters[i]);
- sock->accept_waiters[i] = 0;
- return;
- }
+ cr_sema_signal(&sock->accept_waiters);
}
#endif
@@ -108,7 +104,7 @@ int netio_listen(uint16_t port) {
memset(&addr, 0, sizeof addr);
addr.in.sin_family = AF_INET;
addr.in.sin_port = htons(port);
- sock->fd = socket(AF_INET, SOCK_STREAM | (CONFIG_NETIO_ISLINUX ? 0 : SOCK_NONBLOCK), 0);
+ sock->fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
if (sock->fd < 0)
error(1, errno, "socket");
if (setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)) < 0)
@@ -129,31 +125,29 @@ int netio_listen(uint16_t port) {
}
int netio_accept(int sock) {
-#if CONFIG_NETIO_ISLINUX
- int conn;
- for (int i = 0; i < CONFIG_NETIO_NUM_CONNS; i++)
- if (socket_table[sock].accept_waiters[i] == 0) {
- socket_table[sock].accept_waiters[i] = cr_getcid();
- break;
- }
- cr_pause_and_yield();
- conn = accept(socket_table[sock].fd, NULL, NULL);
- return conn < 0 ? -errno : conn;
-#else
/* AFAICT in pure POSIX there's no good way to do this that
- * isn't just busy-polling. */
+ * isn't just busy-polling.
+ *
+ * On Linux where we can get a signal to notify us when
+ * there's something to accept, we still do this non-blocking
+ * and check EAGAIN/EWOULDBLOCK in case the client timed out
+ * while waiting for us to accept(). */
for (;;) {
+#if CONFIG_NETIO_ISLINUX
+ cr_sema_wait(&sock->accept_waiters);
+#endif
int conn = accept(socket_table[sock].fd, NULL, NULL);
if (conn < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
+#if !CONFIG_NETIO_ISLINUX
cr_yield();
+#endif
continue;
}
return -errno;
}
return conn;
}
-#endif
}
ssize_t netio_read(int conn, void *buf, size_t count) {
@@ -185,7 +179,7 @@ ssize_t netio_write(int conn, void *buf, size_t goal) {
int r;
struct aiocb ctl_block = {
.aio_fildes = conn,
- .aio_buf = &buf[done],
+ .aio_buf = &(((uint8_t *)buf)[done]),
.aio_nbytes = goal-done,
.aio_sigevent = {
.sigev_notify = SIGEV_SIGNAL,
@@ -201,7 +195,7 @@ ssize_t netio_write(int conn, void *buf, size_t goal) {
while ((r = aio_error(&ctl_block)) == EINPROGRESS)
cr_pause_and_yield();
- if ((r) < 0)
+ if (r < 0)
return -abs(r);
done += aio_return(&ctl_block);
}