/* netio_posix.c - netio implementation for POSIX-ish systems * (actually uses a few GNU extensions) * * Copyright (C) 2024 Luke T. Shumaker * SPDX-Licence-Identifier: AGPL-3.0-or-later */ #define _GNU_SOURCE /* for pthread_sigqueue(3gnu) */ /* misc */ #include /* for assert() */ #include /* for errno, EAGAIN, EINVAL */ #include /* for error(3gnu) */ #include /* for abs(), shutdown(), SHUT_RD, SHUT_WR, SHUT_RDWR */ #include /* for read(), write() */ /* net */ #include /* for htons(3p) */ #include /* for struct sockaddr_in */ #include /* for struct sockaddr, socket(), SOCK_* flags, setsockopt(), SOL_SOCKET, SO_REUSEADDR, bind(), listen(), accept() */ /* async */ #include /* for pthread_* */ #include /* for siginfo_t, struct sigaction, enum sigval, sigaction(), SIGRTMIN, SIGRTMAX, SA_SIGINFO */ #include #include /* configuration **************************************************************/ #include "config.h" #ifndef CONFIG_NETIO_NUM_CONNS # error config.h must define CONFIG_NETIO_NUM_CONNS #endif /* common *********************************************************************/ #define UNUSED(name) /* name __attribute__ ((unused)) */ static int sig_io = 0; static void handle_sig_io(int UNUSED(sig), siginfo_t *info, void *UNUSED(ucontext)) { cr_unpause_from_intrhandler((cid_t)info->si_value.sival_int); } static void _netio_init(void) { struct sigaction action = {0}; if (sig_io) return; sig_io = SIGRTMIN; if (sig_io > SIGRTMAX) error(1, 0, "SIGRTMAX exceeded"); action.sa_flags = SA_SIGINFO; action.sa_sigaction = handle_sig_io; if (sigaction(sig_io, &action, NULL) < 0) error(1, errno, "sigaction"); } #define WAKE_COROUTINE(args) do { \ int r; \ union sigval val = {0}; \ val.sival_int = (int)((args)->cr_coroutine); \ do { \ r = pthread_sigqueue((args)->cr_thread, sig_io, val); \ assert(r == 0 || r == EAGAIN); \ } while (r == EAGAIN); \ } while (0) #define RUN_PTHREAD(fn, args) do { \ pthread_t thread; \ int r; \ r = pthread_create(&thread, NULL, fn, args); \ if (r) \ return -abs(r); \ cr_pause_and_yield(); \ r = pthread_join(thread, NULL); \ if (r) \ return -abs(r); \ } while (0) /* listen() *******************************************************************/ int netio_listen(uint16_t port) { int sockfd; union { struct sockaddr_in in; struct sockaddr gen; } addr = { 0 }; _netio_init(); addr.in.sin_family = AF_INET; addr.in.sin_port = htons(port); sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd < 0) error(1, errno, "socket"); if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &(int){1}, sizeof(int)) < 0) error(1, errno, "setsockopt"); if (bind(sockfd, &addr.gen, sizeof addr) < 0) error(1, errno, "bind"); if (listen(sockfd, CONFIG_NETIO_NUM_CONNS) < 0) error(1, errno, "listen"); return sockfd; } /* accept() *******************************************************************/ struct _pthread_accept_args { pthread_t cr_thread; cid_t cr_coroutine; int sockfd; int *ret; }; void *_pthread_accept(void *_args) { struct _pthread_accept_args *args = _args; *(args->ret) = accept(args->sockfd, NULL, NULL); if (*(args->ret) < 0) *(args->ret) = -errno; WAKE_COROUTINE(args); return NULL; }; int netio_accept(int sock) { int ret; struct _pthread_accept_args args = { .cr_thread = pthread_self(), .cr_coroutine = cr_getcid(), .sockfd = sock, .ret = &ret, }; RUN_PTHREAD(_pthread_accept, &args); return ret; } /* read() *********************************************************************/ struct _pthread_read_args { pthread_t cr_thread; cid_t cr_coroutine; int connfd; void *buf; size_t count; ssize_t *ret; }; void *_pthread_read(void *_args) { struct _pthread_read_args *args = _args; *(args->ret) = read(args->connfd, args->buf, args->count); if (*(args->ret) < 0) *(args->ret) = -errno; WAKE_COROUTINE(args); return NULL; }; ssize_t netio_read(int conn, void *buf, size_t count) { ssize_t ret; struct _pthread_read_args args = { .cr_thread = pthread_self(), .cr_coroutine = cr_getcid(), .connfd = conn, .buf = buf, .count = count, .ret = &ret, }; RUN_PTHREAD(_pthread_read, &args); return ret; } /* write() ********************************************************************/ struct _pthread_write_args { pthread_t cr_thread; cid_t cr_coroutine; int connfd; void *buf; size_t count; ssize_t *ret; }; void *_pthread_write(void *_args) { struct _pthread_read_args *args = _args; size_t done = 0; while (done < args->count) { ssize_t r = write(args->connfd, args->buf, args->count); if (r < 0) { *(args->ret) = -errno; break; } done += r; } if (done == args->count) *(args->ret) = done; WAKE_COROUTINE(args); return NULL; }; ssize_t netio_write(int conn, void *buf, size_t count) { ssize_t ret; struct _pthread_write_args args = { .cr_thread = pthread_self(), .cr_coroutine = cr_getcid(), .connfd = conn, .buf = buf, .count = count, .ret = &ret, }; RUN_PTHREAD(_pthread_write, &args); return ret; } /* close() ********************************************************************/ int netio_close(int conn, bool rd, bool wr) { int how; if (rd && wr) how = SHUT_RDWR; else if (rd && !wr) how = SHUT_RD; else if (!rd && wr) how = SHUT_WR; else return -EINVAL; return shutdown(conn, how) ? -errno : 0; }