/* coroutine.c - Simple coroutine and request/response implementation * * Copyright (C) 2024 Luke T. Shumaker * SPDX-Licence-Identifier: AGPL-3.0-or-later */ #include /* for uint8_t */ #include /* for fprintf(), stderr */ #include /* for calloc(), free() */ #include #include #include "cfg_limits.h" #include "coroutine.h" enum coroutine_state { CR_NONE = 0, CR_INITIALIZING, CR_RUNNING, CR_RUNNABLE, CR_PAUSED, }; struct coroutine { enum coroutine_state state; jmp_buf env; size_t stack_size; void *stack; }; static struct coroutine coroutine_table[COROUTINE_NUM] = {0}; static cid_t coroutine_running = 0; static jmp_buf coroutine_add_env; static jmp_buf coroutine_main_env; void coroutine_init(void) {} static void call_with_stack(void *stack, cr_fn_t fn, void *args) { static void *saved_sp = NULL; /* This only really exists for running on ARM-32, but being * able to run it on x86-64 is useful for debugging. */ #if __x86_64__ #define STACK_GROWS_DOWNWARD 1 asm volatile ("movq %%rsp , %0\n\t" /* saved_sp = sp */ "movq %1 , %%rsp\n\t" /* sp = stack */ "movq %3 , %%rdi\n\t" /* arg0 = args */ "call *%2\n\t" /* fn() */ "movq %0 , %%rsp" /* sp = saved_sp */ : : /* %0 */"m"(saved_sp), /* %1 */"r"(stack), /* %2 */"r"(fn), /* %3 */"r"(args) : ); #elif __arm__ #define STACK_GROWS_DOWNWARD 1 asm volatile ("mov r0, coroutine_table[_cur_cid].arg" "mov _saved_stack sp" "mov sp, coroutine_table[_cur_cid].stack" "bl coroutine_table[_cur_cid].fn" "mov _saved_stack sp"); #else #error unsupported architecture #endif } cid_t coroutine_add(cr_fn_t fn, void *args) { static cid_t last_created = 0; assert(coroutine_running == 0 || coroutine_table[coroutine_running-1].state == CR_RUNNING); cid_t child; { size_t idx_base = last_created; for (size_t idx_shift = 0; idx_shift < COROUTINE_NUM; idx_shift++) { child = ((idx_base + idx_shift) % COROUTINE_NUM) + 1; if (coroutine_table[child-1].state == CR_NONE) goto found; } return 0; found: } last_created = child; coroutine_table[child-1].stack_size = COROUTINE_STACK_SIZE; coroutine_table[child-1].stack = calloc(1, coroutine_table[child-1].stack_size); cid_t parent = coroutine_running; assert(parent == 0 || coroutine_table[parent-1].state == CR_RUNNING); coroutine_running = child; coroutine_table[child-1].state = CR_INITIALIZING; if (!setjmp(coroutine_add_env)) { /* point=a */ /* run until cr_begin() */ call_with_stack(coroutine_table[child-1].stack + (STACK_GROWS_DOWNWARD ? coroutine_table[child-1].stack_size : 0), fn, args); /* cr_begin() calls longjmp(point=a); if fn returns * then that means it didn't call cr_begin(), which is * wrong. */ assert(false); } assert(coroutine_table[child-1].state == CR_RUNNABLE); assert(parent == 0 || coroutine_table[parent-1].state == CR_RUNNING); coroutine_running = parent; return child; } void coroutine_main(void) { bool ran; for (coroutine_running = 1;; coroutine_running = (coroutine_running%COROUTINE_NUM)+1) { if (coroutine_running == 1) ran = false; if (coroutine_table[coroutine_running-1].state == CR_RUNNABLE) { ran = true; coroutine_table[coroutine_running-1].state = CR_RUNNING; if (!setjmp(coroutine_main_env)) { /* point=b */ longjmp(coroutine_table[coroutine_running-1].env, 1); /* jump to point=c */ /* Consider returning to be the same as cr_exit(). */ coroutine_table[coroutine_running-1].state = CR_NONE; } if (coroutine_table[coroutine_running-1].state == CR_NONE) { free(coroutine_table[coroutine_running-1].stack); coroutine_table[coroutine_running-1] = (struct coroutine){0}; } } if (coroutine_running == COROUTINE_NUM+1 && !ran) { fprintf(stderr, "error: no runnable coroutines\n"); return; } } } bool cr_begin(void) { assert(coroutine_table[coroutine_running-1].state == CR_INITIALIZING); coroutine_table[coroutine_running-1].state = CR_RUNNABLE; if (!setjmp(coroutine_table[coroutine_running-1].env)) /* point=c1 */ longjmp(coroutine_add_env, 1); /* jump to point=a */ } static inline void _cr_transition(enum coroutine_state state) { assert(coroutine_table[coroutine_running-1].state == CR_RUNNING); coroutine_table[coroutine_running-1].state = state; if (state == CR_NONE || !setjmp(coroutine_table[coroutine_running-1].env)) /* point=c2 */ longjmp(coroutine_main_env, 1); /* jump to point=b */ } void cr_exit(void) { _cr_transition(CR_NONE); } void cr_yield(void) { _cr_transition(CR_RUNNABLE); } void cr_pause_and_yield(void) { _cr_transition(CR_PAUSED); } void cr_unpause(cid_t cid) { assert(coroutine_table[cid-1].state == CR_PAUSED); coroutine_table[cid-1].state = CR_RUNNABLE; } cid_t cr_getcid(void) { return coroutine_running; }