#define LINUX 1 #define NUM_SOCKETS 1 #define NUM_WORKERS 8 #define _GNU_SOURCE #include /* for struct aiocb, aio_read(), aio_write(), aio_error(), aio_return(), SIGEV_SIGNAL */ #include /* for htons() */ #include /* for errno, EAGAIN, EWOULDBLOCK, EINPROGRESS, EINVAL */ #include /* for error() */ #include /* for struct sockaddr_in */ #include /* for siginfo_t, struct sigaction, sigaction(), SIGRTMIN, SIGRTMAX, SA_SIGINFO */ #include /* for shutdown(), SHUT_RD, SHUT_WR, SHUT_RDWR */ #include /* for memset() */ #include /* for struct sockaddr, socket(), SOCK_* flags, setsockopt(), SOL_SOCKET, SO_REUSEADDR, bind(), listen(), accept() */ #if LINUX # include /* for fcntl(), F_SETFL, O_ASYNC, F_SETSIG */ #endif #include "netio.h" #include "coroutine.h" /* I found the following post to be very helpful when writing this: * http://davmac.org/davpage/linux/async-io.html */ static int sigs_allocated = 0; static int sig_io = 0; #if LINUX static int sig_accept = 0; #endif struct netio_socket { int fd; #if LINUX cid_t accept_waiters[NUM_WORKERS]; #endif }; static struct netio_socket socket_table[NUM_SOCKETS] = {0}; static void handle_sig_io(int sig __attribute__ ((unused)), siginfo_t *info, void *ucontext __attribute__ ((unused))) { cr_unpause((cid_t)info->si_value.sival_int); } #if LINUX static void handle_sig_accept(int sig __attribute__ ((unused)), siginfo_t *info, void *ucontext __attribute__ ((unused))) { struct netio_socket *sock = NULL; for (int i = 0; sock == NULL && i < NUM_SOCKETS; i++) if (info->si_fd == socket_table[i].fd) sock = &socket_table[i]; if (!sock) return; for (int i = 0; i < NUM_WORKERS; i++) if (sock->accept_waiters[i] > 0) { cr_unpause(sock->accept_waiters[i]); sock->accept_waiters[i] = 0; return; } } #endif static void _netio_init(void) { struct sigaction action; if (sig_io) return; sig_io = SIGRTMIN + (sigs_allocated++); if (sig_io > SIGRTMAX) error(1, 0, "SIGRTMAX exceeded"); memset(&action, 0, sizeof(action)); action.sa_flags = SA_SIGINFO; action.sa_sigaction = handle_sig_io; if (sigaction(sig_io, &action, NULL) < 0) error(1, errno, "sigaction"); #if LINUX sig_accept = SIGRTMIN + (sigs_allocated++); if (sig_accept > SIGRTMAX) error(1, 0, "SIGRTMAX exceeded"); memset(&action, 0, sizeof(action)); action.sa_flags = SA_SIGINFO; action.sa_sigaction = handle_sig_accept; if (sigaction(sig_accept, &action, NULL) < 0) error(1, errno, "sigaction"); #endif } int netio_listen(uint16_t port) { int handle; struct netio_socket *sock; union { struct sockaddr_in in; struct sockaddr gen; } addr; _netio_init(); /* Allocate a handle out of socket_table. */ handle = -1; for (int i = 0; handle < 0 && i < NUM_SOCKETS; i++) if (socket_table[i].fd == 0) handle = i; if (handle < 0) error(1, 0, "NUM_SOCKETS exceeded"); sock = &socket_table[handle]; /* Bind the socket. */ memset(&addr, 0, sizeof addr); addr.in.sin_family = AF_INET; addr.in.sin_port = htons(port); sock->fd = socket(AF_INET, SOCK_STREAM | (LINUX ? 0 : SOCK_NONBLOCK), 0); if (sock->fd < 0) error(1, errno, "socket"); if (setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)) < 0) error(1, errno, "setsockopt"); #if LINUX if (fcntl(sock->fd, F_SETFL, O_ASYNC) < 0) error(1, errno, "fcntl(F_SETFL)"); if (fcntl(sock->fd, F_SETSIG, sig_accept) < 0) error(1, errno, "fcntl(F_SETSIG)"); #endif if (bind(sock->fd, &addr.gen, sizeof addr) < 0) error(1, errno, "bind"); if (listen(sock->fd, NUM_WORKERS) < 0) error(1, errno, "listen"); /* Return. */ return handle; } int netio_accept(int sock) { #if LINUX int conn; for (int i = 0; i < NUM_WORKERS; i++) if (socket_table[sock].accept_waiters[i] == 0) { socket_table[sock].accept_waiters[i] = cr_getcid(); break; } cr_pause_and_yield(); conn = accept(socket_table[sock].fd, NULL, NULL); return conn < 0 ? -errno : conn; #else /* AFAICT in pure POSIX there's no good way to do this that * isn't just busy-polling. */ for (;;) { int conn = accept(socket_table[sock].fd, NULL, NULL); if (conn < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) { cr_yield(); continue; } return -errno; } return conn; } #endif } size_t netio_read(int conn, void *buf, size_t count) { int r; struct aiocb ctl_block = { .aio_fildes = conn, .aio_buf = buf, .aio_nbytes = count, .aio_sigevent = { .sigev_notify = SIGEV_SIGNAL, .sigev_signo = sig_io, .sigev_value = { .sival_int = (int)cr_getcid(), }, }, }; if (aio_read(&ctl_block) < 0) return -errno; while ((r = aio_error(&ctl_block)) == EINPROGRESS) cr_pause_and_yield(); return r ? -abs(r) : aio_return(&ctl_block); } size_t netio_write(int conn, void *buf, size_t count) { int r; struct aiocb ctl_block = { .aio_fildes = conn, .aio_buf = buf, .aio_nbytes = count, .aio_sigevent = { .sigev_notify = SIGEV_SIGNAL, .sigev_signo = sig_io, .sigev_value = { .sival_int = (int)cr_getcid(), }, }, }; if (aio_write(&ctl_block) < 0) return -errno; while ((r = aio_error(&ctl_block)) == EINPROGRESS) cr_pause_and_yield(); return r ? -abs(r) : aio_return(&ctl_block); } int netio_close(int conn, bool rd, bool wr) { int how; if (rd && wr) how = SHUT_RDWR; else if (rd && !wr) how = SHUT_RD; else if (!rd && wr) how = SHUT_WR; else return -EINVAL; return shutdown(conn, how) ? -errno : 0; }