diff options
-rw-r--r-- | coroutine.c | 53 | ||||
-rw-r--r-- | coroutine.h | 2 | ||||
-rw-r--r-- | coroutine_sema.c | 95 | ||||
-rw-r--r-- | coroutine_sema.h | 35 | ||||
-rw-r--r-- | netio_posix.c | 44 |
5 files changed, 189 insertions, 40 deletions
diff --git a/coroutine.c b/coroutine.c index f138fcf..12d4a25 100644 --- a/coroutine.c +++ b/coroutine.c @@ -154,10 +154,11 @@ enum coroutine_state { */ struct coroutine { - enum coroutine_state state; - jmp_buf env; - size_t stack_size; - void *stack; + volatile enum coroutine_state state; + volatile bool sig_unpause; + jmp_buf env; + size_t stack_size; + void *stack; }; static struct coroutine coroutine_table[CONFIG_COROUTINE_NUM] = {0}; @@ -213,8 +214,10 @@ static void call_with_stack(void *stack, cr_fn_t fn, void *args) { static const uint8_t stack_pattern[] = {0x1e, 0x15, 0x16, 0x0a, 0xcc, 0x52, 0x7e, 0xb7}; #endif +static void inline assert_cid(cid_t cid) { + assert(cid > 0); + assert(cid <= CONFIG_COROUTINE_NUM); #if CONFIG_COROUTINE_PROTECT_STACK -void assert_stack_protection(cid_t cid) { assert(coroutine_table[cid-1].stack_size); assert(coroutine_table[cid-1].stack); for (size_t i = 0; i < sizeof(stack_pattern); i++) { @@ -222,16 +225,12 @@ void assert_stack_protection(cid_t cid) { assert(((uint8_t*)coroutine_table[cid-1].stack)[i] == stack_pattern[i]); assert(((uint8_t*)coroutine_table[cid-1].stack)[j] == stack_pattern[j%sizeof(stack_pattern)]); } -} -#else -# define assert_stack_protection(cid) ((void)0) #endif +} #define assert_cid_state(cid, opstate) do { \ - assert((cid) > 0); \ - assert((cid) <= CONFIG_COROUTINE_NUM); \ + assert_cid(cid); \ assert(coroutine_table[(cid)-1].state opstate); \ - assert_stack_protection(cid); \ } while (0) cid_t coroutine_add_with_stack_size(size_t stack_size, cr_fn_t fn, void *args) { @@ -344,7 +343,6 @@ void cr_begin(void) { } static inline void _cr_transition(enum coroutine_state state) { - assert_cid_state(coroutine_running, == CR_RUNNING); debugf("cid=%zu: transition %i->%i\n", coroutine_running, coroutine_table[coroutine_running-1].state, state); coroutine_table[coroutine_running-1].state = state; @@ -352,8 +350,19 @@ static inline void _cr_transition(enum coroutine_state state) { longjmp(coroutine_main_env, 1); /* jump to point=b */ } -void cr_yield(void) { _cr_transition(CR_RUNNABLE); } -void cr_pause_and_yield(void) { _cr_transition(CR_PAUSED); } +void cr_yield(void) { + assert_cid_state(coroutine_running, == CR_RUNNING); + _cr_transition(CR_RUNNABLE); +} + +void cr_pause_and_yield(void) { + assert_cid_state(coroutine_running, == CR_RUNNING); + if (coroutine_table[cid-1].sig_unpause) + _cr_transition(CR_RUNNABLE); + else + _cr_transition(CR_PAUSED); + coroutine_table[cid-1].sig_unpause = false; +} void cr_exit(void) { assert_cid_state(coroutine_running, == CR_RUNNING); @@ -367,7 +376,21 @@ void cr_unpause(cid_t cid) { assert_cid_state(cid, == CR_PAUSED); debugf("cr_unpause(%zu)\n", cid); - coroutine_table[cid-1].state = CR_RUNNABLE; + coroutine_table[cid-1].sig_unpause = false; + coroutine_table[cid-1].state = CR_RUNNABLE; +} + +void cr_unpause_from_sighandler(cid_t cid) { + assert_cid(cid); + + switch (coroutine_table[cid-1].state) { + case CR_RUNNING: + coroutine_table[cid-1].sig_unpause = true; + case CR_PAUSED: + coroutine_table[cid-1].state = CR_RUNNABLE; + default: + assert(false); + } } cid_t cr_getcid(void) { diff --git a/coroutine.h b/coroutine.h index 4bdc4f8..4d83181 100644 --- a/coroutine.h +++ b/coroutine.h @@ -114,6 +114,8 @@ void cr_yield(void); void cr_pause_and_yield(void); /** cr_unpause() marks a coroutine as runnable that has previously marked itself as non-runnable with cr_pause_and_yield(). */ void cr_unpause(cid_t); +/** cr_unpause_from_sighandler() is like cr_unpause(), but safe to call from a signal handler that might race with the coroutine actually pausing itself. */ +void cr_unpause_from_sighandler(cid_t); /** cr_end() is a counterpart to cr_begin(), but is really just cr_exit(). */ #define cr_end cr_exit diff --git a/coroutine_sema.c b/coroutine_sema.c new file mode 100644 index 0000000..a8b5ec4 --- /dev/null +++ b/coroutine_sema.c @@ -0,0 +1,95 @@ +/* coroutine_sema.h - Simple semaphores for coroutine.{h,c} + * + * Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> + * SPDX-Licence-Identifier: AGPL-3.0-or-later + */ + +#include <assert.h> + +#include "coroutine_sema.h" + +struct cid_list { + cid_t val; + struct cid_list *next; +}; + +/* head->next->next->tail */ + +struct _cr_sema { + int cnt; + + struct cid_list *head, **tail; + /* locked indicates that a call from within a coroutine is is + * messing with ->{head,tail}, so a signal handler can't read + * it. */ + bool locked; +}; + +/** Drain the sema->{head,tail} list. Returns true if cr_getcid() was drained. */ +static inline bool drain(volatile cr_sema_t *sema) { + assert(!sema->locked); + cid_t self = cr_getcid(); + + enum drain_result { + DRAINING, + DRAINED_SELF, /* stopped because drained `self` */ + DRAINED_ALL, /* stopped because sema->head == NULL */ + DRAINED_SOME, /* stopped because sema->cnt == 0 */ + } state = DRAINING; + do { + sema->locked = true; + while (state == DRAINING) { + if (!sema->head) { + state = DRAINED_ALL; + } else if (!sema->cnt) { + state = DRAINED_SOME; + } else { + sema->cnt--; + cid_t cid = sema->head->val; + if (cid == self) + state = DRAINED_SELF; + else + cr_unpause(sema->head->val); + sema->head = sema->head->next; + if (!sema->head) + sema->tail = &sema->head; + } + } + sema->locked = false; + /* If there are still coroutines in sema->head, check + * that sema->cnt wasn't incremented between `if + * (!sema->cnt)` and `sema->locked = false`. */ + } while (state == DRAINED_SOME && cnt); + /* If state == DRAINED_SELF, then we better have been the last + * item in the list! */ + assert(state != DRAINED_SELF || !sema->head); + return state == DRAINED_SELF; +} + +void cr_sema_signal(volatile cr_sema_t *sema) { + sema->cnt++; + if (!sema->locked) + drain(); +} + +void cr_sema_wait(volatile cr_sema_t *sema) { + struct cid_list self = { + .val = cr_getcid(), + .next = NULL, + }; + + sema->locked = true; + if (!sema->tail) + sema->head = &self; + else + *(sema->tail) = &self; + sema->tail = &(self.next); + sema->locked = false; + + if (drain()) + /* DRAINED_SELF: (1) No need to pause+yield, (2) we + * better have been the last item in the list! */ + assert(!self.next); + else + cr_pause_and_yield(); +} diff --git a/coroutine_sema.h b/coroutine_sema.h new file mode 100644 index 0000000..57d5855 --- /dev/null +++ b/coroutine_sema.h @@ -0,0 +1,35 @@ +/* coroutine_sema.h - Simple semaphores for coroutine.{h,c} + * + * Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> + * SPDX-Licence-Identifier: AGPL-3.0-or-later + */ + +#ifndef _COROUTINE_SEMA_H_ +#define _COROUTINE_SEMA_H_ + +/** + * A cr_sema_t is a fair unbounded[1] counting semaphore. + * + * [1]: Well, INT_MAX + */ +typedef struct sema_t cr_sema_t; + +/** + * Increment the semaphore, + * + * @blocks never + * @yields never + * @run_in anywhere (coroutine, sighandler) + */ +void cr_sema_signal(cr_sema_t *sema); + +/** + * Wait until the semaphore is >0, then decrement it. + * + * @blocks maybe + * @yields maybe + * @may_run_in coroutine + */ +void cr_sema_wait(cr_sema_t *sema); + +#endif /* _COROUTINE_SEMA_H_ */ diff --git a/netio_posix.c b/netio_posix.c index 3cc00bb..46851f7 100644 --- a/netio_posix.c +++ b/netio_posix.c @@ -18,6 +18,7 @@ #include "netio.h" #include "coroutine.h" +#include "coroutine_sema.h" /* I found the following post to be very helpful when writing this: * http://davmac.org/davpage/linux/async-io.html */ @@ -29,16 +30,16 @@ static int sig_accept = 0; #endif struct netio_socket { - int fd; + int fd; #if CONFIG_NETIO_ISLINUX - cid_t accept_waiters[CONFIG_NETIO_NUM_CONNS]; + cr_sema_t accept_waiters; #endif }; static struct netio_socket socket_table[CONFIG_NETIO_NUM_PORTS] = {0}; static void handle_sig_io(int sig __attribute__ ((unused)), siginfo_t *info, void *ucontext __attribute__ ((unused))) { - cr_unpause((cid_t)info->si_value.sival_int); + cr_unpause_from_sighandler((cid_t)info->si_value.sival_int); } #if CONFIG_NETIO_ISLINUX @@ -49,12 +50,7 @@ static void handle_sig_accept(int sig __attribute__ ((unused)), siginfo_t *info, sock = &socket_table[i]; if (!sock) return; - for (int i = 0; i < CONFIG_NETIO_NUM_CONNS; i++) - if (sock->accept_waiters[i] > 0) { - cr_unpause(sock->accept_waiters[i]); - sock->accept_waiters[i] = 0; - return; - } + cr_sema_signal(&sock->accept_waiters); } #endif @@ -108,7 +104,7 @@ int netio_listen(uint16_t port) { memset(&addr, 0, sizeof addr); addr.in.sin_family = AF_INET; addr.in.sin_port = htons(port); - sock->fd = socket(AF_INET, SOCK_STREAM | (CONFIG_NETIO_ISLINUX ? 0 : SOCK_NONBLOCK), 0); + sock->fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); if (sock->fd < 0) error(1, errno, "socket"); if (setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)) < 0) @@ -129,31 +125,29 @@ int netio_listen(uint16_t port) { } int netio_accept(int sock) { -#if CONFIG_NETIO_ISLINUX - int conn; - for (int i = 0; i < CONFIG_NETIO_NUM_CONNS; i++) - if (socket_table[sock].accept_waiters[i] == 0) { - socket_table[sock].accept_waiters[i] = cr_getcid(); - break; - } - cr_pause_and_yield(); - conn = accept(socket_table[sock].fd, NULL, NULL); - return conn < 0 ? -errno : conn; -#else /* AFAICT in pure POSIX there's no good way to do this that - * isn't just busy-polling. */ + * isn't just busy-polling. + * + * On Linux where we can get a signal to notify us when + * there's something to accept, we still do this non-blocking + * and check EAGAIN/EWOULDBLOCK in case the client timed out + * while waiting for us to accept(). */ for (;;) { +#if CONFIG_NETIO_ISLINUX + cr_sema_wait(&sock->accept_waiters); +#endif int conn = accept(socket_table[sock].fd, NULL, NULL); if (conn < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { +#if !CONFIG_NETIO_ISLINUX cr_yield(); +#endif continue; } return -errno; } return conn; } -#endif } ssize_t netio_read(int conn, void *buf, size_t count) { @@ -185,7 +179,7 @@ ssize_t netio_write(int conn, void *buf, size_t goal) { int r; struct aiocb ctl_block = { .aio_fildes = conn, - .aio_buf = &buf[done], + .aio_buf = &(((uint8_t *)buf)[done]), .aio_nbytes = goal-done, .aio_sigevent = { .sigev_notify = SIGEV_SIGNAL, @@ -201,7 +195,7 @@ ssize_t netio_write(int conn, void *buf, size_t goal) { while ((r = aio_error(&ctl_block)) == EINPROGRESS) cr_pause_and_yield(); - if ((r) < 0) + if (r < 0) return -abs(r); done += aio_return(&ctl_block); } |