diff options
Diffstat (limited to 'libcr/coroutine.c')
-rw-r--r-- | libcr/coroutine.c | 82 |
1 files changed, 49 insertions, 33 deletions
diff --git a/libcr/coroutine.c b/libcr/coroutine.c index c182fad..fe2183e 100644 --- a/libcr/coroutine.c +++ b/libcr/coroutine.c @@ -246,6 +246,7 @@ /** Return `n` rounded up to the nearest multiple of `d` */ #define ROUND_UP(n, d) ( ( ((n)+(d)-1) / (d) ) * (d) ) #define ARRAY_LEN(arr) (sizeof(arr)/sizeof((arr)[0])) +#define NEXT_POWER_OF_2(x) ((1ULL)<<((sizeof(unsigned long long)*8)-__builtin_clzll(x))) /* types **********************************************************************/ @@ -321,6 +322,17 @@ static jmp_buf coroutine_main_env; */ static struct coroutine coroutine_table[CONFIG_COROUTINE_NUM] = {0}; +static struct { + /* tail == head means empty */ + /* buf[tail] is the next thing to run */ + /* buf[head] is where the next entry will go */ + size_t head, tail; + /* Having this be a power of 2 has 2 benefits: (a) the + * compiler will optimize `%array_len` to &(array_len-1)`, (b) + * we don't have to worry about funny wrap-around behavior + * when head or tail overflow. */ + cid_t buf[NEXT_POWER_OF_2(CONFIG_COROUTINE_NUM)]; +} coroutine_ringbuf = {0}; static cid_t coroutine_running = 0; /* utility functions **********************************************************/ @@ -349,6 +361,18 @@ static inline const char* coroutine_state_str(enum coroutine_state state) { return coroutine_state_strs[state]; } +static inline void coroutine_ringbuf_push(cid_t val) { + coroutine_ringbuf.buf[coroutine_ringbuf.head++ % ARRAY_LEN(coroutine_ringbuf.buf)] = val; + assert((coroutine_ringbuf.head % ARRAY_LEN(coroutine_ringbuf.buf)) != + (coroutine_ringbuf.tail % ARRAY_LEN(coroutine_ringbuf.buf))); +} + +static inline cid_t coroutine_ringbuf_pop(void) { + if (coroutine_ringbuf.tail == coroutine_ringbuf.head) + return 0; + return coroutine_ringbuf.buf[coroutine_ringbuf.tail++ % ARRAY_LEN(coroutine_ringbuf.buf)]; +} + static inline void assert_cid(cid_t cid) { assert(cid > 0); assert(cid <= CONFIG_COROUTINE_NUM); @@ -426,6 +450,7 @@ cid_t coroutine_add_with_stack_size(size_t stack_size, cr_fn_t fn, void *args) { assert(false); /* should cr_begin() instead of returning */ } assert_cid_state(child, state == CR_RUNNABLE); + coroutine_ringbuf_push(child); if (parent) assert_cid_state(parent, state == CR_RUNNING); coroutine_running = parent; @@ -440,27 +465,20 @@ cid_t coroutine_add(cr_fn_t fn, void *args) { /* coroutine_main() ***********************************************************/ -static inline cid_t next_coroutine() { - for (cid_t next = (coroutine_running%CONFIG_COROUTINE_NUM)+1; - next != coroutine_running; - next = (next%CONFIG_COROUTINE_NUM)+1) { - if (coroutine_table[next-1].state == CR_RUNNABLE) - return next; - } - return 0; -} - void coroutine_main(void) { debugf("coroutine_main()\n"); cr_disable_interrupts(); coroutine_running = 0; for (;;) { - cid_t next = next_coroutine(); - if (!next) { + cid_t next; + while ( !((next = coroutine_ringbuf_pop())) ) { + /* No coroutines are runnable, wait for an interrupt + * to change that. */ cr_enable_interrupts(); - errorf("no coroutines\n"); - return; + cr_plat_wait_for_interrupt(); + cr_disable_interrupts(); } + if (!cr_plat_setjmp(coroutine_main_env)) { /* point=b */ coroutine_running = next; coroutine_table[coroutine_running-1].state = CR_RUNNING; @@ -487,21 +505,9 @@ void cr_begin(void) { cr_enable_interrupts(); } -static inline void _cr_transition(enum coroutine_state state) { - coroutine_table[coroutine_running-1].state = state; - +static inline void _cr_yield() { cid_t next; - for (;;) { - next = next_coroutine(); - if (next) - /* Switch to `next`. */ - break; - if (coroutine_table[coroutine_running-1].state == CR_RUNNABLE) { - /* No other coroutine is runnable, don't - * transition after all. */ - coroutine_table[coroutine_running-1].state = CR_RUNNING; - return; - } + while ( !((next = coroutine_ringbuf_pop())) ) { /* No coroutines are runnable, wait for an interrupt * to change that. */ cr_enable_interrupts(); @@ -509,6 +515,11 @@ static inline void _cr_transition(enum coroutine_state state) { cr_disable_interrupts(); } + if (next == coroutine_running) { + coroutine_table[coroutine_running-1].state = CR_RUNNING; + return; + } + coroutine_table[coroutine_running-1].sp = cr_plat_get_sp(); if (!cr_plat_setjmp(coroutine_table[coroutine_running-1].env)) { /* point=c2 */ coroutine_running = next; @@ -522,7 +533,8 @@ void cr_yield(void) { assert_cid_state(coroutine_running, state == CR_RUNNING); cr_disable_interrupts(); - _cr_transition(CR_RUNNABLE); + coroutine_table[coroutine_running-1].state = CR_RUNNABLE; + _cr_yield(); cr_enable_interrupts(); } @@ -531,10 +543,12 @@ void cr_pause_and_yield(void) { assert_cid_state(coroutine_running, state == CR_RUNNING || state == CR_PRE_RUNNABLE); cr_disable_interrupts(); - if (coroutine_table[coroutine_running-1].state == CR_PRE_RUNNABLE) - _cr_transition(CR_RUNNABLE); - else - _cr_transition(CR_PAUSED); + if (coroutine_table[coroutine_running-1].state == CR_PRE_RUNNABLE) { + coroutine_table[coroutine_running-1].state = CR_RUNNABLE; + coroutine_ringbuf_push(coroutine_running); + } else + coroutine_table[coroutine_running-1].state = CR_PAUSED; + _cr_yield(); cr_enable_interrupts(); } @@ -552,6 +566,7 @@ void cr_unpause(cid_t cid) { assert_cid_state(cid, state == CR_PAUSED); coroutine_table[cid-1].state = CR_RUNNABLE; + coroutine_ringbuf_push(cid); } void cr_unpause_from_intrhandler(cid_t cid) { @@ -565,6 +580,7 @@ void cr_unpause_from_intrhandler(cid_t cid) { } else { debugf("... actual unpause\n"); coroutine_table[cid-1].state = CR_RUNNABLE; + coroutine_ringbuf_push(cid); } } |