/* libcr/coroutine.c - Simple embeddable coroutine implementation * * Copyright (C) 2024 Luke T. Shumaker * SPDX-Licence-Identifier: AGPL-3.0-or-later */ #include #include /* for setjmp(), longjmp(), jmp_buf */ #include /* for uint8_t */ #include /* for printf(), fprintf(), stderr */ #include /* for aligned_alloc(), free() */ #include /* Configuration **************************************************************/ #include "config.h" #ifndef CONFIG_COROUTINE_DEFAULT_STACK_SIZE # error config.h must define CONFIG_COROUTINE_DEFAULT_STACK_SIZE #endif #ifndef CONFIG_COROUTINE_NUM # error config.h must define CONFIG_COROUTINE_NUM #endif #ifndef CONFIG_COROUTINE_MEASURE_STACK # error config.h must define CONFIG_COROUTINE_MEASURE_STACK #endif #ifndef CONFIG_COROUTINE_PROTECT_STACK # error config.h must define CONFIG_COROUTINE_PROTECT_STACK #endif #ifndef CONFIG_COROUTINE_DEBUG # error config.h must define CONFIG_COROUTINE_DEBUG #endif #ifndef CONFIG_COROUTINE_VALGRIND # error config.h must define CONFIG_COROUTINE_VALGRIND #endif #if CONFIG_COROUTINE_VALGRIND # include #endif /* Implementation *************************************************************/ /* * Portability notes: * * - It uses GCC `__attribute__`s, and the GNUC ({ ... }) statement * exprs extension. * * - It has a small bit of platform-specific code in the "platform * support" section. Other than this, it should be portable to * other platforms CPUs. It currently contains implementations for * __x86_64__ (assumes POSIX) and __arm__ (assumes bare-metal), and * should be fairly easy to add implementations for other platforms. * * - It uses setjmp()/longjmp() in "unsafe" ways. POSIX-2017 * longjmp(3p) says * * > If the most recent invocation of setjmp() with the * > corresponding jmp_buf ... or if the function containing the * > invocation of setjmp() has terminated execution in the interim, * > or if the invocation of setjmp() was within the scope of an * > identifier with variably modified type and execution has left * > that scope in the interim, the behavior is undefined. * * We use longjmp() both of these scenarios, but make it OK by using * call_with_stack() to manage the stack ourselves, assuming the * sole reason that longjmp() behavior is undefined in such cases is * because the stack that its saved stack-pointer points to is no * longer around. It seems absurd that an implementation would * choose to do something else, but I'm calling it out here because * you never know. * * Our assumptions seem to be valid for * x86_64-pc-linux-gnu/gcc-14.2.1/glibc-2.40 and * arm-none-eabi/gcc-14.1.0/newlib-4.4.0. * * Why not use , the now-deprecated (was in * POSIX.1-2001, is gone in POSIX-2008) predecesor to ? * It would let us do this without any assembly or unsafe * assumptions. Simply: because newlib does not provide it. But it * would let us avoid having an `sp` member in `struct coroutine`... * Maybe https://github.com/kaniini/libucontext ? Or building a * ucontext-lib abstraction on top of setjmp/longjmp? */ /* * Design decisions and notes: * * - Coroutines are launched with a stack that is filled with known * arbitrary values. Because all stack variables should be * initialized before they are read, this "shouldn't" make a * difference, but: (1) Initializing it to a known value allows us * to measure how much of the stack was written to, which is helpful * to tune stack sizes. (2) Leaving it uninitialized just gives me * the willies. * * - Because embedded programs should be adverse to using the heap, * CONFIG_COROUTINE_NUM is fixed, instead of having coroutine_add() * dynamically grow the coroutine_table as-needed. * * - On the flip-side, coroutine stacks are allocated on the heap * instead of having them be statically-allocated along with * coroutine_table. (1) This reduced the blast-area of damage for a * stack-overflow; and indeed if the end of the stack alignes with a * page-boundary then memory-protection can even detect the overflow * for us. (2) Having different-looking addresses for stack-area vs * static-area is handy for making things jump out at you when * debugging. (3) This can likely also improve things with being * page-aligned. * * - Coroutines must use cr_exit() instead of returning because if * they return then they will return to call_with_stack() in * coroutine_add() (not to after the longjmp() call in * coroutine_main()), and besides being * wrong-for-our-desired-flow-control, that's a stack location that * no longer exists. */ #define ALWAYS_INLINE inline __attribute__((always_inline)) /* platform support ***********************************************************/ /* As part of sbc-harness, this only really needs to support ARM-32, but being * able to run it on my x86-64 GNU/Linux laptop is useful for debugging. */ #define CR_PLAT_STACK_ALIGNMENT \ ({ __attribute__((aligned)) void fn(void) {}; __alignof__(fn); }) #if 0 { /* bracket to get Emacs indentation to work how I want */ #endif /*==================================================================== * Wrappers for setjmp()/longjmp() that do *not* save the * interrupt mask. */ #if __unix__ /* On a *NIX OS, we use signals as interrupts. POSIX leaves * it implementation-defined whether setjmp()/longjmp() save * the signal mask; while glibc does not save it, let's not * rely on that. */ #define cr_plat_setjmp(env) sigsetjmp(env, 0) #define cr_plat_longjmp(env, val) siglongjmp(env, val) #elif __NEWLIB__ /* newlib does not have sigsetjmp()/sigsetlongjmp(), but * setjmp()/longjmp() do not save the interrupt mask, * so we * can use them directly. */ #define cr_plat_setjmp(env) setjmp(env) #define cr_plat_longjmp(env, val) longjmp(env, val) #else #error unsupported platform (not __unix__, not __NEWLIB__) #endif /*==================================================================== * Interrupt management routines. */ #if __unix__ #include /* for sig*, SIG_* */ #include /* for pause() */ static inline void cr_plat_wait_for_interrupt(void) { pause(); } void _cr_plat_disable_interrupts(void) { sigset_t all; sigfillset(&all); sigprocmask(SIG_BLOCK, &all, NULL); } void _cr_plat_enable_interrupts(void) { sigset_t all; sigfillset(&all); sigprocmask(SIG_UNBLOCK, &all, NULL); } #elif __arm__ /* Assume bare-metal if !__unix__. */ static ALWAYS_INLINE void cr_plat_wait_for_interrupt(void) { asm volatile ("wfi":::"memory"); } void _cr_plat_disable_interrupts(void) { asm volatile ("cpsid i":::"memory"); } void _cr_plat_enable_interrupts(void) { asm volatile ("cpsie i":::"memory"); } #else #error unsupported platform (not __unix__, not bare-metal __arm__) #endif /*==================================================================== * Stack management routines. */ #if __arm__ #define CR_PLAT_STACK_GROWS_DOWNWARD 1 #if CONFIG_COROUTINE_MEASURE_STACK static ALWAYS_INLINE uintptr_t cr_plat_get_sp(void) { uintptr_t sp; asm volatile ("mov %0, sp":"=r"(sp)); return sp; } #endif static void cr_plat_call_with_stack(void *stack, cr_fn_t fn, void *args) { static void *saved_sp = NULL; /* str/ldr can only work with a "lo" register, which * sp is not, so we use r0 as an intermediate because * we're going to clobber it with args anyway. */ asm volatile ("mov r0, sp\n\t" /* [saved_sp = sp */ "str r0, %0\n\t" /* ] */ "mov sp, %1\n\t" /* [sp = stack] */ "mov r0, %3\n\t" /* [arg0 = args] */ "blx %2\n\t" /* [fn()] */ "ldr r0, %0\n\t" /* [sp = staved_sp */ "mov sp, r0" /* ] */ : : /* %0 */"m"(saved_sp), /* %1 */"r"(stack), /* %2 */"r"(fn), /* %3 */"r"(args) : "r0" ); } #elif __x86_64__ #define CR_PLAT_STACK_GROWS_DOWNWARD 1 #if CONFIG_COROUTINE_MEASURE_STACK static ALWAYS_INLINE uintptr_t cr_plat_get_sp(void) { uintptr_t sp; asm volatile ("movq %%rsp, %0":"=r"(sp)); return sp; } #endif static void cr_plat_call_with_stack(void *stack, cr_fn_t fn, void *args) { static void *saved_sp = NULL; asm volatile ("movq %%rsp , %0\n\t" /* saved_sp = sp */ "movq %1 , %%rsp\n\t" /* sp = stack */ "movq %3 , %%rdi\n\t" /* arg0 = args */ "call *%2\n\t" /* fn() */ "movq %0 , %%rsp" /* sp = saved_sp */ : : /* %0 */"m"(saved_sp), /* %1 */"r"(stack), /* %2 */"r"(fn), /* %3 */"r"(args) : "rdi" ); } #else #error unsupported platform (not __arm__, not __x86__) #endif #if 0 } #endif /* preprocessor macros ********************************************************/ /** Return `n` rounded up to the nearest multiple of `d` */ #define ROUND_UP(n, d) ( ( ((n)+(d)-1) / (d) ) * (d) ) #define ARRAY_LEN(arr) (sizeof(arr)/sizeof((arr)[0])) #define NEXT_POWER_OF_2(x) ((1ULL)<<((sizeof(unsigned long long)*8)-__builtin_clzll(x))) /* types **********************************************************************/ enum coroutine_state { CR_NONE = 0, /* this slot in the table is empty */ CR_INITIALIZING, /* running, before cr_begin() */ CR_RUNNING, /* running, after cr_begin() */ CR_PRE_RUNNABLE, /* running, after cr_unpause_from_intrhandler() * but before cr_pause() */ CR_RUNNABLE, /* not running, but runnable */ CR_PAUSED, /* not running, and not runnable */ }; struct coroutine { volatile enum coroutine_state state; jmp_buf env; #if CONFIG_COROUTINE_MEASURE_STACK /* We aught to be able to get this out of .env, but libc * authors insist on jmp_buf being opaque, glibc going as far * as to xor it with a secret ot obfuscate it! */ uintptr_t sp; #endif size_t stack_size; void *stack; #if CONFIG_COROUTINE_VALGRIND unsigned stack_id; #endif }; /* constants ******************************************************************/ const char *coroutine_state_strs[] = { [CR_NONE] = "CR_NONE", [CR_INITIALIZING] = "CR_INITIALIZING", [CR_RUNNING] = "CR_RUNNING", [CR_PRE_RUNNABLE] = "CR_PRE_RUNNABLE", [CR_RUNNABLE] = "CR_RUNNABLE", [CR_PAUSED] = "CR_PAUSED", }; #if CONFIG_COROUTINE_MEASURE_STACK || CONFIG_COROUTINE_PROTECT_STACK /* We just need a pattern that is unlikely to occur naturaly; this is * just a few bytes that I read from /dev/random. */ static const uint8_t stack_pattern[] = { 0xa1, 0x31, 0xe6, 0x07, 0x1f, 0x61, 0x20, 0x32, 0x4b, 0x14, 0xc4, 0xe0, 0xea, 0x62, 0x25, 0x63, }; #endif #if CONFIG_COROUTINE_PROTECT_STACK # define STACK_GUARD_SIZE \ ROUND_UP(sizeof(stack_pattern), CR_PLAT_STACK_ALIGNMENT) #else # define STACK_GUARD_SIZE 0 #endif /* global variables ***********************************************************/ static jmp_buf coroutine_add_env; static jmp_buf coroutine_main_env; /* * Invariants (and non-invariants): * * - exactly 0 or 1 coroutines have state CR_INITIALIZING * - exactly 0 or 1 coroutines have state CR_RUNNING * - if coroutine_running is non-zero, then * coroutine_table[coroutine_running-1] is the currently-running * coroutine * - the coroutine_running coroutine either has state CR_RUNNING or * CR_INITIALIZNG * - a coroutine having state CR_RUNNING does *not* imply that * coroutine_running points at that coroutine; if that coroutine is * in the middle of coroutine_add(), it coroutine_running points at * the CR_INITIALIZING child coroutine, while leaving the parent * coroutine as CR_RUNNING. * - a coroutine has state CR_RUNNABLE if and only if it is is in the * coroutine_ringbuf queue. */ static struct coroutine coroutine_table[CONFIG_COROUTINE_NUM] = {0}; static struct { /* tail == head means empty */ /* buf[tail] is the next thing to run */ /* buf[head] is where the next entry will go */ size_t head, tail; /* Having this be a power of 2 has 2 benefits: (a) the * compiler will optimize `%array_len` to &(array_len-1)`, (b) * we don't have to worry about funny wrap-around behavior * when head or tail overflow. */ cid_t buf[NEXT_POWER_OF_2(CONFIG_COROUTINE_NUM)]; } coroutine_ringbuf = {0}; static cid_t coroutine_running = 0; /* utility functions **********************************************************/ #define errorf(...) fprintf(stderr, "error: " __VA_ARGS__) #define infof(...) printf("info: " __VA_ARGS__) #if CONFIG_COROUTINE_DEBUG # define debugf(...) printf("dbg: " __VA_ARGS__) #else # define debugf(...) #endif #ifdef __GLIBC__ # define assertf(expr, ...) ({ \ if (!(expr)) { \ errorf("assertion: " __VA_ARGS__); \ __assert_fail(#expr, __FILE__, __LINE__, __func__); \ } \ }) #else # define assertf(expr, ...) assert(expr) #endif static inline const char* coroutine_state_str(enum coroutine_state state) { assert(state < ARRAY_LEN(coroutine_state_strs)); return coroutine_state_strs[state]; } static inline void coroutine_ringbuf_push(cid_t val) { coroutine_ringbuf.buf[coroutine_ringbuf.head++ % ARRAY_LEN(coroutine_ringbuf.buf)] = val; assert((coroutine_ringbuf.head % ARRAY_LEN(coroutine_ringbuf.buf)) != (coroutine_ringbuf.tail % ARRAY_LEN(coroutine_ringbuf.buf))); } static inline cid_t coroutine_ringbuf_pop(void) { if (coroutine_ringbuf.tail == coroutine_ringbuf.head) return 0; return coroutine_ringbuf.buf[coroutine_ringbuf.tail++ % ARRAY_LEN(coroutine_ringbuf.buf)]; } static inline void assert_cid(cid_t cid) { assert(cid > 0); assert(cid <= CONFIG_COROUTINE_NUM); #if CONFIG_COROUTINE_PROTECT_STACK assert(coroutine_table[cid-1].stack_size); uint8_t *stack = coroutine_table[cid-1].stack; assert(stack); for (size_t i = 0; i < STACK_GUARD_SIZE; i++) { size_t j = coroutine_table[cid-1].stack_size - (i+1); assert(stack[i] == stack_pattern[i%sizeof(stack_pattern)]); assert(stack[j] == stack_pattern[j%sizeof(stack_pattern)]); } #endif } #define assert_cid_state(cid, expr) do { \ assert_cid(cid); \ cid_t state = coroutine_table[(cid)-1].state; \ assert(expr); \ } while (0) /* coroutine_add() ************************************************************/ cid_t coroutine_add_with_stack_size(size_t stack_size, cr_fn_t fn, void *args) { static cid_t last_created = 0; cid_t parent = coroutine_running; if (parent) assert_cid_state(parent, state == CR_RUNNING); assert(stack_size); assert(fn); debugf("coroutine_add_with_stack_size(%zu, %p, %p)...\n", stack_size, fn, args); cid_t child; { size_t base = last_created; for (size_t shift = 0; shift < CONFIG_COROUTINE_NUM; shift++) { child = ((base + shift) % CONFIG_COROUTINE_NUM) + 1; if (coroutine_table[child-1].state == CR_NONE) goto found; } return 0; found: } debugf("...child=%zu\n", child); last_created = child; coroutine_table[child-1].stack_size = stack_size; coroutine_table[child-1].stack = aligned_alloc(CR_PLAT_STACK_ALIGNMENT, stack_size); #if CONFIG_COROUTINE_MEASURE_STACK || CONFIG_COROUTINE_PROTECT_STACK for (size_t i = 0; i < stack_size; i++) ((uint8_t *)coroutine_table[child-1].stack)[i] = stack_pattern[i%sizeof(stack_pattern)]; #endif #if CONFIG_COROUTINE_VALGRIND coroutine_table[child-1].stack_id = VALGRIND_STACK_REGISTER( coroutine_table[child-1].stack + STACK_GUARD_SIZE, coroutine_table[child-1].stack + stack_size - STACK_GUARD_SIZE); #endif coroutine_running = child; coroutine_table[child-1].state = CR_INITIALIZING; if (!cr_plat_setjmp(coroutine_add_env)) { /* point=a */ void *stack_base = coroutine_table[child-1].stack #if CR_PLAT_STACK_GROWS_DOWNWARD + stack_size - STACK_GUARD_SIZE #else + STACK_GUARD_SIZE #endif ; debugf("...stack =%p\n", coroutine_table[child-1].stack); debugf("...stack_base=%p\n", stack_base); /* run until cr_begin() */ cr_plat_call_with_stack(stack_base, fn, args); __builtin_unreachable(); /* should cr_begin() instead of returning */ } assert_cid_state(child, state == CR_RUNNABLE); if (parent) assert_cid_state(parent, state == CR_RUNNING); coroutine_running = parent; return child; } cid_t coroutine_add(cr_fn_t fn, void *args) { return coroutine_add_with_stack_size( CONFIG_COROUTINE_DEFAULT_STACK_SIZE, fn, args); } /* coroutine_main() ***********************************************************/ void coroutine_main(void) { debugf("coroutine_main()\n"); cr_disable_interrupts(); coroutine_running = 0; for (;;) { cid_t next; while ( !((next = coroutine_ringbuf_pop())) ) { /* No coroutines are runnable, wait for an interrupt * to change that. */ cr_enable_interrupts(); cr_plat_wait_for_interrupt(); cr_disable_interrupts(); } if (!cr_plat_setjmp(coroutine_main_env)) { /* point=b */ coroutine_running = next; coroutine_table[coroutine_running-1].state = CR_RUNNING; cr_plat_longjmp(coroutine_table[coroutine_running-1].env, 1); /* jump to point=c */ } /* This is where we jump to from cr_exit(), and from * nowhere else. */ assert_cid_state(coroutine_running, state == CR_NONE); #if CONFIG_COROUTINE_VALGRIND VALGRIND_STACK_DEREGISTER(coroutine_table[coroutine_running-1].stack_id); #endif free(coroutine_table[coroutine_running-1].stack); coroutine_table[coroutine_running-1] = (struct coroutine){0}; } } /* cr_*() *********************************************************************/ void cr_begin(void) { debugf("cid=%zu: cr_begin()\n", coroutine_running); assert_cid_state(coroutine_running, state == CR_INITIALIZING); coroutine_table[coroutine_running-1].state = CR_RUNNABLE; coroutine_ringbuf_push(coroutine_running); coroutine_table[coroutine_running-1].sp = cr_plat_get_sp(); if (!cr_plat_setjmp(coroutine_table[coroutine_running-1].env)) /* point=c1 */ cr_plat_longjmp(coroutine_add_env, 1); /* jump to point=a */ cr_enable_interrupts(); } static inline void _cr_yield() { cid_t next; while ( !((next = coroutine_ringbuf_pop())) ) { /* No coroutines are runnable, wait for an interrupt * to change that. */ cr_enable_interrupts(); cr_plat_wait_for_interrupt(); cr_disable_interrupts(); } if (next == coroutine_running) { coroutine_table[coroutine_running-1].state = CR_RUNNING; return; } coroutine_table[coroutine_running-1].sp = cr_plat_get_sp(); if (!cr_plat_setjmp(coroutine_table[coroutine_running-1].env)) { /* point=c2 */ coroutine_running = next; coroutine_table[coroutine_running-1].state = CR_RUNNING; cr_plat_longjmp(coroutine_table[coroutine_running-1].env, 1); /* jump to point=c */ } } void cr_yield(void) { debugf("cid=%zu: cr_yield()\n", coroutine_running); assert_cid_state(coroutine_running, state == CR_RUNNING); cr_disable_interrupts(); coroutine_table[coroutine_running-1].state = CR_RUNNABLE; coroutine_ringbuf_push(coroutine_running); _cr_yield(); cr_enable_interrupts(); } void cr_pause_and_yield(void) { debugf("cid=%zu: cr_pause_and_yield()\n", coroutine_running); assert_cid_state(coroutine_running, state == CR_RUNNING || state == CR_PRE_RUNNABLE); cr_disable_interrupts(); if (coroutine_table[coroutine_running-1].state == CR_PRE_RUNNABLE) { coroutine_table[coroutine_running-1].state = CR_RUNNABLE; coroutine_ringbuf_push(coroutine_running); } else coroutine_table[coroutine_running-1].state = CR_PAUSED; _cr_yield(); cr_enable_interrupts(); } void cr_exit(void) { debugf("cid=%zu: cr_exit()\n", coroutine_running); assert_cid_state(coroutine_running, state == CR_RUNNING); cr_disable_interrupts(); coroutine_table[coroutine_running-1].state = CR_NONE; cr_plat_longjmp(coroutine_main_env, 1); /* jump to point=b */ } void cr_unpause(cid_t cid) { debugf("cr_unpause(%zu)\n", cid); assert_cid_state(cid, state == CR_PAUSED); coroutine_table[cid-1].state = CR_RUNNABLE; coroutine_ringbuf_push(cid); } void cr_unpause_from_intrhandler(cid_t cid) { debugf("cr_unpause_from_intrhandler(%zu)\n", cid); assert_cid_state(cid, state == CR_RUNNING || state == CR_PAUSED); if (coroutine_table[cid-1].state == CR_RUNNING) { assert(cid == coroutine_running); debugf("... raced, deferring unpause\n"); coroutine_table[cid-1].state = CR_PRE_RUNNABLE; } else { debugf("... actual unpause\n"); coroutine_table[cid-1].state = CR_RUNNABLE; coroutine_ringbuf_push(cid); } } cid_t cr_getcid(void) { return coroutine_running; } /* cr_cid_info() **************************************************************/ #if CONFIG_COROUTINE_MEASURE_STACK void cr_cid_info(cid_t cid, struct cr_cid_info *ret) { assert_cid(cid); assert(ret); /* stack_cap */ ret->stack_cap = coroutine_table[cid-1].stack_size - 2*STACK_GUARD_SIZE; /* stack_max */ ret->stack_max = ret->stack_cap; uint8_t *stack = (uint8_t *)coroutine_table[cid-1].stack; for (;;) { size_t i = #if CR_PLAT_STACK_GROWS_DOWNWARD STACK_GUARD_SIZE + ret->stack_cap - ret->stack_max #else ret->stack_max - 1 - STACK_GUARD_SIZE #endif ; if (ret->stack_max == 0 || stack[i] != stack_pattern[i%sizeof(stack_pattern)]) break; ret->stack_max--; } /* stack_cur */ uintptr_t sp; if (cid == coroutine_running) sp = cr_plat_get_sp(); else sp = coroutine_table[cid-1].sp; assert(sp); uintptr_t sb = (uintptr_t)coroutine_table[cid-1].stack; #if CR_PLAT_STACK_GROWS_DOWNWARD ret->stack_cur = (sb - STACK_GUARD_SIZE) - sp; #else ret->stack_cur = sp - (sb + STACK_GUARD_SIZE); #endif } #endif /* CONFIG_COROUTINE_MEASURE_STACK */