/* libcr/coroutine.c - Simple embeddable coroutine implementation
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <stdint.h> /* for uint8_t */
#include <stdlib.h> /* for aligned_alloc(), free() */
#include <string.h> /* for strncpy(), memset() */

#include <libmisc/assert.h>
#include <libmisc/macro.h>

#define LOG_NAME COROUTINE
#include <libmisc/log.h>

#include <libcr/coroutine.h>
#undef COROUTINE

/* Configuration **************************************************************/

#include "config.h"

#ifndef CONFIG_COROUTINE_DEFAULT_STACK_SIZE
	#error config.h must define CONFIG_COROUTINE_DEFAULT_STACK_SIZE (non-negative integer)
#endif
#ifndef CONFIG_COROUTINE_NAME_LEN
	#error config.h must define CONFIG_COROUTINE_NAME_LEN (non-negative integer)
#endif
#ifndef CONFIG_COROUTINE_NUM
	#error config.h must define CONFIG_COROUTINE_NUM (non-negative integer)
#endif
#ifndef CONFIG_COROUTINE_MEASURE_STACK
	#error config.h must define CONFIG_COROUTINE_MEASURE_STACK (bool)
#endif
#ifndef CONFIG_COROUTINE_PROTECT_STACK
	#error config.h must define CONFIG_COROUTINE_PROTECT_STACK (bool)
#endif
#ifndef CONFIG_COROUTINE_DEBUG
	#error config.h must define CONFIG_COROUTINE_DEBUG (bool)
#endif
#ifndef CONFIG_COROUTINE_VALGRIND
	#error config.h must define CONFIG_COROUTINE_VALGRIND (bool)
#endif
#ifndef CONFIG_COROUTINE_GDB
	#error config.h must define CONFIG_COROUTINE_GDB (bool)
#endif

/* Implementation *************************************************************/

#if CONFIG_COROUTINE_VALGRIND
	#include <valgrind/valgrind.h>
#endif

/*
 * Portability notes:
 *
 * - It uses GCC `gnu::` attributes, and the GNUC `({ ... })`
 *   statement exprs extension.
 *
 * - It has a small bit of platform-specific code in the "platform
 *   support" section.  Other than this, it should be portable to
 *   other platforms and CPUs.  It currently contains implementations for
 *   __unix__ and __ARM_EABI__ "operating systems" on __x86_64__ and
 *   __ARM_ARCH_6M__ CPUs, and should be fairly easy to add
 *   implementations for other platforms.
 *
 * - It uses setjmp()/longjmp() in "unsafe" ways.  POSIX-2017
 *   longjmp(3p) says
 *
 *   > If the most recent invocation of setjmp() with the
 *   > corresponding jmp_buf ... or if the function containing the
 *   > invocation of setjmp() has terminated execution in the interim,
 *   > or if the invocation of setjmp() was within the scope of an
 *   > identifier with variably modified type and execution has left
 *   > that scope in the interim, the behavior is undefined.
 *
 *   We use longjmp() in both of these scenarios, but make it OK by using
 *   call_with_stack() to manage the stack ourselves, assuming the
 *   sole reason that longjmp() behavior is undefined in such cases is
 *   because the stack that its saved stack-pointer points to is no
 *   longer around.  It seems absurd that an implementation would
 *   choose to do something else, but I'm calling it out here because
 *   you never know.
 *
 *   Our assumptions seem to be valid for
 *   x86_64-pc-linux-gnu/gcc-14.2.1/glibc-2.40 and
 *   arm-none-eabi/gcc-14.1.0/newlib-4.4.0.
 *
 *   Why not use <ucontext.h>, the now-deprecated (was in
 *   POSIX.1-2001, is gone in POSIX-2008) predecessor to <setjmp.h>?
 *   It would let us do this without any assembly or unsafe
 *   assumptions.  Simply: because newlib does not provide it.  But it
 *   would let us avoid having an `sp` member in `struct coroutine`...
 *   Maybe https://github.com/kaniini/libucontext ?  Or building a
 *   ucontext-lib abstraction on top of setjmp/longjmp?
 */

/*
 * Design decisions and notes:
 *
 * - Coroutines are launched with a stack that is filled with known
 *   arbitrary values.  Because all stack variables should be
 *   initialized before they are read, this "shouldn't" make a
 *   difference, but: (1) Initializing it to a known value allows us
 *   to measure how much of the stack was written to, which is helpful
 *   to tune stack sizes.  (2) Leaving it uninitialized just gives me
 *   the willies.
 *
 * - Because embedded programs should be averse to using the heap,
 *   CONFIG_COROUTINE_NUM is fixed, instead of having coroutine_add()
 *   dynamically grow the coroutine_table as-needed.
 *
 * - On the flip-side, coroutine stacks are allocated on the heap
 *   instead of having them be statically-allocated along with
 *   coroutine_table.  (1) This reduces the blast-area of damage for a
 *   stack-overflow; and indeed if the end of the stack aligns with a
 *   page-boundary then memory-protection can even detect the overflow
 *   for us.  (2) Having different-looking addresses for stack-area vs
 *   static-area is handy for making things jump out at you when
 *   debugging.  (3) This can likely also improve things with being
 *   page-aligned.
 *
 * - Coroutines must use cr_exit() instead of returning because if
 *   they return then they will return to call_with_stack() in
 *   coroutine_add() (not to after the longjmp() call in
 *   coroutine_main()), and besides being
 *   wrong-for-our-desired-flow-control, that's a stack location that
 *   no longer exists.
 */

/* platform support ***********************************************************/

/* As part of sbc-harness, this only really needs to support ARM-32, but being
 * able to run it on my x86-64 GNU/Linux laptop is useful for debugging.  */

/* The alignment required for a stack, obtained by asking the
 * compiler for the default alignment of a [[gnu::aligned]] nested
 * function (GNU statement-expr + nested-function extensions).
 * NOTE(review): presumably this equals the ABI's stack-alignment
 * requirement on all supported platforms — confirm per-target.  */
#define CR_PLAT_STACK_ALIGNMENT \
	({ [[gnu::aligned]] void fn(void) {}; __alignof__(fn); })

#if 0
{ /* bracket to get Emacs indentation to work how I want */
#endif

/*====================================================================
 * Interrupt management routines.  */
#if __unix__
	#include <signal.h> /* for sig*, SIG* */

	/* For a signal to be *in* the mask means that the signal is
	 * *blocked*.  */

	#define _CR_SIG_SENTINEL SIGURG
	#if CONFIG_COROUTINE_GDB
	#define _CR_SIG_GDB      SIGWINCH
	#endif

	/* Return whether we are currently executing inside an
	 * interrupt (signal) handler.  On __unix__, "interrupts" are
	 * the POSIX real-time signals, and a handler runs with its
	 * own signal blocked; so: any RT signal blocked while the
	 * sentinel is *not* blocked means we are inside a handler.  */
	bool cr_plat_is_in_intrhandler(void) {
		sigset_t cur_mask;
		sigfillset(&cur_mask); /* defensive init; overwritten just below */
		sigprocmask(0, NULL, &cur_mask); /* `how` is ignored when set==NULL */
		if (sigismember(&cur_mask, _CR_SIG_SENTINEL))
			/* Interrupts are disabled, so we cannot be in
			 * an interrupt handler.  */
			return false;
		for (int sig = SIGRTMIN; sig <= SIGRTMAX; sig++)
			if (sigismember(&cur_mask, sig))
				return true;
		return false;
	}
	/* Interrupts are enabled iff the sentinel signal is currently
	 * unblocked.  Must not be called from an interrupt handler.  */
	static inline bool _cr_plat_are_interrupts_enabled(void) {
		assert(!cr_plat_is_in_intrhandler());
		sigset_t mask;
		sigfillset(&mask); /* defensive init; overwritten just below */
		sigprocmask(0, NULL, &mask);
		bool sentinel_blocked = sigismember(&mask, _CR_SIG_SENTINEL);
		return !sentinel_blocked;
	}

	/* Atomically unblock all signals and sleep until one is
	 * delivered, then re-block everything before returning.
	 * Requires that interrupts are currently disabled.  */
	static inline void cr_plat_wait_for_interrupt(void) {
		assert(!cr_plat_is_in_intrhandler());
		assert(!_cr_plat_are_interrupts_enabled());
		sigset_t set;
		sigemptyset(&set);
		sigsuspend(&set); /* wait with all signals unblocked */

		/* sigsuspend() restores the previous (all-blocked, per
		 * the assert above) mask on return; re-assert it
		 * explicitly anyway.  */
		sigfillset(&set);
		sigprocmask(SIG_SETMASK, &set, NULL);
	}
	/* Block every signal, and report whether interrupts had been
	 * enabled beforehand (i.e. whether the sentinel had been
	 * unblocked), so the caller can later restore that state.  */
	bool _cr_plat_save_and_disable_interrupts(void) {
		assert(!cr_plat_is_in_intrhandler());
		sigset_t new_mask, old_mask;
		sigfillset(&new_mask);
		sigprocmask(SIG_SETMASK, &new_mask, &old_mask);
		bool were_enabled = !sigismember(&old_mask, _CR_SIG_SENTINEL);
		return were_enabled;
	}
	/* Unblock every signal; interrupts must currently be
	 * disabled.  */
	void _cr_plat_enable_interrupts(void) {
		assert(!cr_plat_is_in_intrhandler());
		assert(!_cr_plat_are_interrupts_enabled());
		sigset_t none;
		sigemptyset(&none);
		sigprocmask(SIG_SETMASK, &none, NULL);
	}
	#if CONFIG_COROUTINE_GDB
	/* No-op handler for the GDB-integration signal; having a
	 * handler installed guarantees delivery does not kill or
	 * silently ignore the process.  */
	static void _cr_gdb_intrhandler(int LM_UNUSED(sig)) {}
	#endif
	/* One-time platform setup; called lazily from coroutine_add()
	 * and coroutine_main().  */
	static void cr_plat_init(void) {
	#if CONFIG_COROUTINE_GDB
		int r;
		struct sigaction action = {
			.sa_handler = _cr_gdb_intrhandler,
		};
		r = sigaction(_CR_SIG_GDB, &action, NULL);
		assert(r == 0);
	#endif
	}
#elif __ARM_ARCH_6M__ && __ARM_EABI__
	/* Return whether we are executing in an exception/interrupt
	 * handler: the IPSR special register holds the active
	 * exception number, and reads 0 in thread mode.  */
	bool cr_plat_is_in_intrhandler(void) {
		uint32_t isr_number;
		asm volatile ("mrs %0, ipsr"
		              : /* %0 */"=l"(isr_number)
		              );
		return isr_number != 0;
	}
	/* Interrupts are enabled iff the PRIMASK special register
	 * reads 0 (cpsid i sets it to 1, cpsie i clears it).  */
	LM_ALWAYS_INLINE static bool _cr_plat_are_interrupts_enabled(void) {
		assert(!cr_plat_is_in_intrhandler());
		uint32_t primask;
		asm volatile ("mrs %0, PRIMASK"
		              : /* %0 */"=l"(primask)
		              );
		return primask == 0;
	}

	/* Sleep until an interrupt is pending, let its handler run,
	 * and return with interrupts disabled again: `wfi` waits,
	 * `cpsie i` briefly enables interrupts so the pending handler
	 * actually executes, `isb` ensures it does so before the
	 * final `cpsid i` re-disables interrupts.  */
	LM_ALWAYS_INLINE static void cr_plat_wait_for_interrupt(void) {
		assert(!cr_plat_is_in_intrhandler());
		assert(!_cr_plat_are_interrupts_enabled());
		asm volatile ("wfi\n"
		              "cpsie i\n"
		              "isb\n"
		              "cpsid i"
		              :::"memory");
	}
	/* Disable interrupts, returning whether they had been enabled
	 * (so the caller can later restore that state).  */
	bool _cr_plat_save_and_disable_interrupts(void) {
		assert(!cr_plat_is_in_intrhandler());
		bool were_enabled = _cr_plat_are_interrupts_enabled();
		asm volatile ("cpsid i");
		return were_enabled;
	}
	/* Re-enable interrupts; they must currently be disabled.  */
	void _cr_plat_enable_interrupts(void) {
		assert(!cr_plat_is_in_intrhandler());
		assert(!_cr_plat_are_interrupts_enabled());
		asm volatile ("cpsie i");
	}
	/* No per-platform setup is needed on bare-metal ARM.  */
	static void cr_plat_init(void) {}
#else
	#error unsupported platform (not __unix__, not __ARM_ARCH_6M__ && __ARM_EABI__)
#endif

/*====================================================================
 * Stack management routines.  */
#if __ARM_ARCH_6M__
	#define CR_PLAT_STACK_GROWS_DOWNWARD 1

	#if CONFIG_COROUTINE_MEASURE_STACK
	/* Return the current value of the stack pointer.  */
	LM_ALWAYS_INLINE static uintptr_t cr_plat_get_sp(void) {
		uintptr_t sp;
		asm volatile ("mov %0, sp":"=r"(sp));
		return sp;
	}
	#endif

	/* Switch to `stack`, call fn(args) on it, and switch back to
	 * the original stack when fn returns.  `saved_sp` must be
	 * static: while running on the new stack, a stack-local would
	 * not be addressable relative to the (switched) sp.  */
	static void cr_plat_call_with_stack(void *stack,
	                                    cr_fn_t fn, void *args) {
		static void *saved_sp = NULL;
		/* str/ldr can only work with a "lo" register, which
		 * sp is not, so we use r0 as an intermediate because
		 * we're going to clobber it with args anyway.  */
		asm volatile ("mov r0, sp\n\t" /* [saved_sp = sp */
		              "str r0, %0\n\t" /* ] */
		              "mov sp, %1\n\t" /* [sp = stack] */
		              "mov r0, %3\n\t" /* [arg0 = args] */
		              "blx %2\n\t"     /* [fn()] */
		              "ldr r0, %0\n\t" /* [sp = saved_sp */
		              "mov sp, r0"     /* ] */
		              :
		              : /* %0 */"m"(saved_sp),
		                /* %1 */"l"(stack),
		                /* %2 */"l"(fn),
		                /* %3 */"l"(args)
		              : "r0"
		              );
	}
#elif __x86_64__
	#define CR_PLAT_STACK_GROWS_DOWNWARD 1

	#if CONFIG_COROUTINE_MEASURE_STACK
	/* Return the current value of the stack pointer (rsp).  */
	LM_ALWAYS_INLINE static uintptr_t cr_plat_get_sp(void)  {
		uintptr_t sp;
		asm volatile ("movq %%rsp, %0":"=r"(sp));
		return sp;
	}
	#endif

	/* Switch to `stack`, call fn(args) on it, and switch back to
	 * the original stack when fn returns.  `saved_sp` must be
	 * static: while running on the new stack, a stack-local would
	 * not be addressable relative to the (switched) rsp.
	 *
	 * NOTE(review): `call *%2` invokes an arbitrary C function,
	 * which per the SysV ABI may clobber all caller-saved
	 * registers, yet only rdi is declared in the clobber list —
	 * confirm no live registers are assumed across this asm.  */
	static void cr_plat_call_with_stack(void *stack,
	                                    cr_fn_t fn, void *args) {
		static void *saved_sp = NULL;
		asm volatile ("movq %%rsp , %0\n\t"    /* saved_sp = sp */
		              "movq %1    , %%rsp\n\t" /* sp = stack */
		              "movq %3    , %%rdi\n\t" /* arg0 = args */
		              "call *%2\n\t"           /* fn() */
		              "movq %0    , %%rsp"     /* sp = saved_sp */
		              :
		              : /* %0 */"m"(saved_sp),
		                /* %1 */"r"(stack),
		                /* %2 */"r"(fn),
		                /* %3 */"r"(args)
		              : "rdi"
		              );
	}
#else
	#error unsupported CPU (not __ARM_ARCH_6M__, not __x86_64__)
#endif

/*====================================================================
 * Wrappers for setjmp()/longjmp() that:
 *  1. Allow us to inspect the buffer.
 *  2. Do *not* save the interrupt mask.
 */
	#include <setjmp.h> /* for setjmp(), longjmp(), jmp_buf */
	/* A jmp_buf plus (optionally) a copy of the stack pointer at
	 * setjmp time, for stack measurement.  */
	typedef struct {
		jmp_buf                        raw;
	#if CONFIG_COROUTINE_MEASURE_STACK
		/* We ought to be able to get sp out of the raw
		 * `jmp_buf`, but libc authors insist on jmp_buf being
		 * opaque, glibc going as far as to xor it with a
		 * secret to obfuscate it!  */
		uintptr_t                      sp;
	#endif
	} cr_plat_jmp_buf;
	/* Record the stack pointer alongside the jmp_buf; invoked by
	 * cr_plat_setjmp() immediately before the real setjmp.  */
	static void _cr_plat_setjmp_pre(cr_plat_jmp_buf *env [[gnu::unused]]) {
	#if CONFIG_COROUTINE_MEASURE_STACK
		env->sp = cr_plat_get_sp();
	#endif
	}
	#if CONFIG_COROUTINE_MEASURE_STACK
	/* Return the stack pointer that was saved at setjmp time.  */
	static uintptr_t cr_plat_setjmp_get_sp(cr_plat_jmp_buf *env) { return env->sp; }
	#endif
	/* cr_plat_setjmp *NEEDS* to be a preprocessor macro rather
	 * than a real function, because [[gnu::returns_twice]]
	 * doesn't work.
	 * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=117469 */
#if __unix__
	/* On __unix__, we use POSIX real-time signals as interrupts.
	 * POSIX leaves it implementation-defined whether
	 * setjmp()/longjmp() save the signal mask; while glibc does
	 * not save it, let's not rely on that.  */
	#define cr_plat_setjmp(env) ({ _cr_plat_setjmp_pre(env); sigsetjmp((env)->raw, 0); })
	[[noreturn]] static void cr_plat_longjmp(cr_plat_jmp_buf *env, int val) { siglongjmp(env->raw, val); }
#elif __NEWLIB__
	/* newlib does not have sigsetjmp()/sigsetlongjmp(), but
	 * setjmp()/longjmp() do not save the interrupt mask, so we
	 * can use them directly.  */
	#define cr_plat_setjmp(env) ({ _cr_plat_setjmp_pre(env); setjmp((env)->raw); })
	[[noreturn]] static void cr_plat_longjmp(cr_plat_jmp_buf *env, int val) { longjmp(env->raw, val); }
#else
	#error unsupported platform (not __unix__, not __NEWLIB__)
#endif

#if 0
}
#endif

/* types **********************************************************************/

/* The lifecycle state of a coroutine-table slot.  */
enum coroutine_state {
	CR_NONE = 0,     /* this slot in the table is empty */
	CR_INITIALIZING, /* running, before cr_begin() */
	CR_RUNNING,      /* running, after cr_begin() */
	CR_RUNNABLE,     /* not running, but runnable */
	CR_PAUSED,       /* not running, and not runnable */
};

/* Everything the scheduler tracks about one coroutine.  */
struct coroutine {
	volatile enum coroutine_state  state;
	cr_plat_jmp_buf                env;        /* saved context, valid while not running */
	size_t                         stack_size; /* total allocation, including guards */
	void                          *stack;      /* heap allocation from aligned_alloc() */
#if CONFIG_COROUTINE_VALGRIND
	unsigned                       stack_id;   /* handle from VALGRIND_STACK_REGISTER() */
#endif
	char                           name[CONFIG_COROUTINE_NAME_LEN]; /* NOTE: may lack NUL-termination (strncpy) */
};

/* constants ******************************************************************/

/* Printable names for enum coroutine_state, indexed by state.  */
const char *coroutine_state_strs[] = {
	[CR_NONE]         = "CR_NONE",
	[CR_INITIALIZING] = "CR_INITIALIZING",
	[CR_RUNNING]      = "CR_RUNNING",
	[CR_RUNNABLE]     = "CR_RUNNABLE",
	[CR_PAUSED]       = "CR_PAUSED",
};

#if CONFIG_COROUTINE_MEASURE_STACK || CONFIG_COROUTINE_PROTECT_STACK
/* We just need a pattern that is unlikely to occur naturally; this is
 * just a few bytes that I read from /dev/random.  */
static const uint8_t stack_pattern[] = {
	0xa1, 0x31, 0xe6, 0x07, 0x1f, 0x61, 0x20, 0x32,
	0x4b, 0x14, 0xc4, 0xe0, 0xea, 0x62, 0x25, 0x63,
};
#endif
#if CONFIG_COROUTINE_PROTECT_STACK
	/* Size of the canary region at *each* end of every coroutine
	 * stack; rounded up so the usable region stays aligned.  */
	#define CR_STACK_GUARD_SIZE \
		LM_ROUND_UP(sizeof(stack_pattern), CR_PLAT_STACK_ALIGNMENT)
#else
	#define CR_STACK_GUARD_SIZE 0
#endif

/* global variables ***********************************************************/

static bool            coroutine_initialized = false;
/* Context of coroutine_add() at point=a; cr_begin() jumps here.  */
static cr_plat_jmp_buf coroutine_add_env;
/* Context of coroutine_main() at point=b; cr_exit() jumps here.  */
static cr_plat_jmp_buf coroutine_main_env;
#if CONFIG_COROUTINE_GDB
/* Scratch context used by cr_gdb_readjmp() to bounce into and back
 * out of a coroutine's saved environment.  */
static cr_plat_jmp_buf coroutine_gdb_env;
#endif

/*
 * Invariants (and non-invariants):
 *
 * - exactly 0 or 1 coroutines have state CR_INITIALIZING
 * - exactly 0 or 1 coroutines have state CR_RUNNING
 * - if coroutine_running is non-zero, then
 *   coroutine_table[coroutine_running-1] is the currently-running
 *   coroutine
 * - the coroutine_running coroutine either has state CR_RUNNING or
 *   CR_INITIALIZING
 * - a coroutine having state CR_RUNNING does *not* imply that
 *   coroutine_running points at that coroutine; if that coroutine is
 *   in the middle of coroutine_add(), then coroutine_running points at
 *   the CR_INITIALIZING child coroutine, while leaving the parent
 *   coroutine as CR_RUNNING.
 * - a coroutine has state CR_RUNNABLE if and only if it is in the
 *   coroutine_ringbuf queue.
 */

static struct coroutine coroutine_table[CONFIG_COROUTINE_NUM] = {0};
static struct {
	/* tail == head means empty */
	/* buf[tail] is the next thing to run */
	/* buf[head] is where the next entry will go */
	size_t  head, tail;
	/* Having this be a power of 2 has 2 benefits: (a) the
	 * compiler will optimize `%array_len` to &(array_len-1)`, (b)
	 * we don't have to worry about funny wrap-around behavior
	 * when head or tail overflow.  */
	cid_t   buf[LM_NEXT_POWER_OF_2(CONFIG_COROUTINE_NUM)];
}                       coroutine_ringbuf                     = {0};
/* ID of the currently-running (or CR_INITIALIZING) coroutine; 0 when
 * the scheduler itself is running.  */
static cid_t            coroutine_running                     = 0;
/* Number of table slots whose state is not CR_NONE.  */
static size_t           coroutine_cnt                         = 0;

/* utility functions **********************************************************/

/* Map a state to its printable name; an out-of-range value indicates
 * corruption and trips the assertion.  */
static inline const char* coroutine_state_str(enum coroutine_state state) {
	assert(state < LM_ARRAY_LEN(coroutine_state_strs));
	const char *str = coroutine_state_strs[state];
	return str;
}

/* Append a coroutine ID to the run queue.  The buffer has capacity
 * for every possible coroutine, so it can never actually fill; the
 * assertion double-checks that.  */
static inline void coroutine_ringbuf_push(cid_t val) {
	size_t len = LM_ARRAY_LEN(coroutine_ringbuf.buf);
	coroutine_ringbuf.buf[coroutine_ringbuf.head % len] = val;
	coroutine_ringbuf.head++;
	assert((coroutine_ringbuf.head % len) !=
	       (coroutine_ringbuf.tail % len));
}

static inline cid_t coroutine_ringbuf_pop(void) {
	if (coroutine_ringbuf.tail == coroutine_ringbuf.head)
		return 0;
	return coroutine_ringbuf.buf[coroutine_ringbuf.tail++ % LM_ARRAY_LEN(coroutine_ringbuf.buf)];
}

#if CONFIG_COROUTINE_GDB
/* An empty function for GDB scripts to set a breakpoint on.  */
LM_NEVER_INLINE void cr_gdb_breakpoint(void) {
	/* Prevent the call from being optimized away. */
	asm ("");
}
/* For use *by GDB*: jump into the context saved in `env` with
 * setjmp-value 2, so that the cr_setjmp() there detours through
 * cr_gdb_breakpoint() and then bounces straight back here.  */
LM_NEVER_INLINE void cr_gdb_readjmp(cr_plat_jmp_buf *env) {
	if (!cr_plat_setjmp(&coroutine_gdb_env))
		cr_plat_longjmp(env, 2);
}
/* Like cr_plat_setjmp(), but when resumed with value 2 (only ever
 * done by cr_gdb_readjmp()), hit the GDB breakpoint and return to
 * cr_gdb_readjmp() instead of resuming normal execution.  */
#define cr_setjmp(env) ({                                               \
			int val = cr_plat_setjmp(env);                  \
			if (val == 2) {                                 \
				cr_gdb_breakpoint();                    \
				cr_plat_longjmp(&coroutine_gdb_env, 1); \
			}                                               \
			val;                                            \
		})
#else
#define cr_setjmp(env)  cr_plat_setjmp(env)
#endif
#define cr_longjmp(env) cr_plat_longjmp(env, 1)

/* Assert that `cid` is a plausible coroutine ID; when stack
 * protection is enabled, additionally verify that the guard regions
 * at both ends of the coroutine's stack still hold the fill pattern
 * (i.e. that the stack has not overflowed or underflowed).  */
static inline void assert_cid(cid_t cid) {
	assert(cid > 0);
	assert(cid <= CONFIG_COROUTINE_NUM);
#if CONFIG_COROUTINE_PROTECT_STACK
	assert(coroutine_table[cid-1].stack_size);
	uint8_t *stack = coroutine_table[cid-1].stack;
	assert(stack);
	for (size_t i = 0; i < CR_STACK_GUARD_SIZE; i++) {
		/* i walks the low guard from the bottom up;
		 * j walks the high guard from the top down.  */
		size_t j = coroutine_table[cid-1].stack_size - (i+1);
		assert(stack[i] == stack_pattern[i%sizeof(stack_pattern)]);
		assert(stack[j] == stack_pattern[j%sizeof(stack_pattern)]);
	}
#endif
}

/* Assert that `cid` names a valid coroutine (see assert_cid()) and
 * that the predicate `expr` holds; `expr` may reference the local
 * variable `state`, which holds coroutine_table[cid-1].state.
 * Fix: `state` was declared as `cid_t` (a coroutine-ID type) even
 * though it holds an `enum coroutine_state`; declare it with its
 * actual type.  */
#define assert_cid_state(cid, expr) do {                                     \
		assert_cid(cid);                                             \
		enum coroutine_state state = coroutine_table[(cid)-1].state; \
		assert(expr);                                                \
	} while (0)


/* coroutine_add() ************************************************************/

/* Create a new coroutine with an explicitly-sized stack.
 *
 * @param stack_size : bytes to allocate for the stack (including the
 *                     guard regions, if CONFIG_COROUTINE_PROTECT_STACK)
 * @param name       : human-readable name (may be NULL); truncated to
 *                     CONFIG_COROUTINE_NAME_LEN bytes
 * @param fn, args   : the coroutine body and its argument
 *
 * @return the new coroutine's ID, or 0 on failure (table full, or
 *         stack allocation failed)
 *
 * Runs `fn` on its new stack until it calls cr_begin(), then returns
 * in the caller's context.
 */
cid_t coroutine_add_with_stack_size(size_t stack_size,
                                    const char *name,
                                    cr_fn_t fn, void *args) {
	static cid_t last_created = 0;
	cid_t parent = coroutine_running;

	if (parent)
		assert_cid_state(parent, state == CR_RUNNING);
	assert(stack_size);
	assert(fn);
	debugf("coroutine_add_with_stack_size(%zu, \"%s\", %p, %p)...",
	       stack_size, name, fn, args);

	if (!coroutine_initialized) {
		cr_plat_init();
		coroutine_initialized = true;
	}

	/* Find a free table slot, scanning round-robin from the
	 * most-recently-created coroutine so that IDs are not
	 * immediately reused.  */
	cid_t child;
	{
		size_t base = last_created;
		for (size_t shift = 0; shift < CONFIG_COROUTINE_NUM; shift++) {
			child = ((base + shift) % CONFIG_COROUTINE_NUM) + 1;
			if (coroutine_table[child-1].state == CR_NONE)
				goto found;
		}
		/* Table is full.  */
		return 0;
	found:;
	}
	debugf("...child=%zu", child);

	last_created = child;

	if (name)
		/* NOTE(review): strncpy() leaves .name without a NUL
		 * terminator when strlen(name) >= sizeof(.name);
		 * confirm that all readers of .name bound their
		 * reads.  */
		strncpy(coroutine_table[child-1].name, name, sizeof(coroutine_table[child-1].name));
	else
		memset(coroutine_table[child-1].name, 0, sizeof(coroutine_table[child-1].name));

	coroutine_table[child-1].stack_size = stack_size;
	infof("allocing \"%s\" stack with size %zu", name, stack_size);
	/* NOTE(review): C11 requires `stack_size` to be a multiple of
	 * the alignment for aligned_alloc(); confirm callers honor
	 * that.  */
	coroutine_table[child-1].stack =
		aligned_alloc(CR_PLAT_STACK_ALIGNMENT, stack_size);
	if (!coroutine_table[child-1].stack) {
		/* Fix: the allocation result was previously used
		 * unchecked.  Roll the slot back to its pristine
		 * CR_NONE shape and report failure the same way as a
		 * full table.  */
		infof("...failed");
		coroutine_table[child-1] = (struct coroutine){0};
		return 0;
	}
	infof("...done");
#if CONFIG_COROUTINE_MEASURE_STACK || CONFIG_COROUTINE_PROTECT_STACK
	/* Paint the whole stack with a known pattern so that
	 * stack-usage measurement and guard checking can tell which
	 * bytes have been written.  */
	for (size_t i = 0; i < stack_size; i++)
		((uint8_t *)coroutine_table[child-1].stack)[i] =
			stack_pattern[i%sizeof(stack_pattern)];
#endif
#if CONFIG_COROUTINE_VALGRIND
	coroutine_table[child-1].stack_id = VALGRIND_STACK_REGISTER(
		coroutine_table[child-1].stack + CR_STACK_GUARD_SIZE,
		coroutine_table[child-1].stack + stack_size - CR_STACK_GUARD_SIZE);
#endif

	coroutine_running = child;
	coroutine_table[child-1].state = CR_INITIALIZING;
	coroutine_cnt++;
	if (!cr_setjmp(&coroutine_add_env)) { /* point=a */
		/* Initial stack pointer: just inside the guard region
		 * at the end the stack grows away from.  */
		void *stack_base = coroutine_table[child-1].stack
#if CR_PLAT_STACK_GROWS_DOWNWARD
			+ stack_size
			- CR_STACK_GUARD_SIZE
#else
			+ CR_STACK_GUARD_SIZE
#endif
			;
		debugf("...stack     =%p", coroutine_table[child-1].stack);
		debugf("...stack_base=%p", stack_base);
		/* run until cr_begin() */
		cr_plat_call_with_stack(stack_base, fn, args);
		assert_notreached("should cr_begin() instead of returning");
	}
	/* cr_begin() longjmp()ed us back here.  */
	assert_cid_state(child, state == CR_RUNNABLE);
	if (parent)
		assert_cid_state(parent, state == CR_RUNNING);
	/* Restore interrupts because cr_begin() disables interrupts
	 * before the context switch.  XXX: This assumes that
	 * interrupts were enabled when _add() was called, which we
	 * didn't actually check.  */
	cr_restore_interrupts(true);
	coroutine_running = parent;

	return child;
}

/* Like coroutine_add_with_stack_size(), but using the
 * configuration-default stack size.  */
cid_t coroutine_add(const char *name, cr_fn_t fn, void *args) {
	size_t stack_size = CONFIG_COROUTINE_DEFAULT_STACK_SIZE;
	return coroutine_add_with_stack_size(stack_size, name, fn, args);
}

/* coroutine_main() ***********************************************************/

/* The scheduler: repeatedly run queued coroutines until every
 * coroutine has cr_exit()ed, then return.  */
void coroutine_main(void) {
	debugf("coroutine_main()");
	if (!coroutine_initialized) {
		cr_plat_init();
		coroutine_initialized = true;
	}
	/* The scheduler itself runs with interrupts disabled; they
	 * are only (briefly) enabled while waiting for one, or while
	 * a coroutine chooses to enable them.  */
	bool saved = cr_save_and_disable_interrupts();
	assert(saved);
	assert(!cr_plat_is_in_intrhandler());
	coroutine_running = 0;
#if CONFIG_COROUTINE_GDB
	/* Some pointless call to prevent cr_gdb_readjmp() from
	 * getting pruned out of the firmware image.  */
	if (coroutine_table[0].state != CR_NONE)
		cr_gdb_readjmp(&coroutine_table[0].env);
#endif
	while (coroutine_cnt) {
		cid_t next;
		while ( !((next = coroutine_ringbuf_pop())) ) {
			/* No coroutines are runnable, wait for an interrupt
			 * to change that.  */
			cr_plat_wait_for_interrupt();
		}

		if (!cr_setjmp(&coroutine_main_env)) { /* point=b */
			coroutine_running = next;
			coroutine_table[coroutine_running-1].state = CR_RUNNING;
			cr_longjmp(&coroutine_table[coroutine_running-1].env); /* jump to point=c */
		}
		/* This is where we jump to from cr_exit(), and from
		 * nowhere else.  Tear down the exited coroutine:
		 * release its stack and clear its table slot.  */
		assert_cid_state(coroutine_running, state == CR_NONE);
#if CONFIG_COROUTINE_VALGRIND
		VALGRIND_STACK_DEREGISTER(coroutine_table[coroutine_running-1].stack_id);
#endif
		free(coroutine_table[coroutine_running-1].stack);
		coroutine_table[coroutine_running-1] = (struct coroutine){0};
		coroutine_cnt--;
	}
	cr_restore_interrupts(saved);
}

/* cr_*() *********************************************************************/

/* Called by a coroutine body (exactly once, at startup) to finish
 * initialization: mark itself runnable, enqueue itself, and context
 * switch back to coroutine_add() in the parent's context.  */
void cr_begin(void) {
	debugf("cid=%zu: cr_begin()", coroutine_running);
	assert_cid_state(coroutine_running, state == CR_INITIALIZING);

	bool saved = cr_save_and_disable_interrupts();
	coroutine_table[coroutine_running-1].state = CR_RUNNABLE;
	coroutine_ringbuf_push(coroutine_running);
	if (!cr_setjmp(&coroutine_table[coroutine_running-1].env)) /* point=c1 */
		cr_longjmp(&coroutine_add_env); /* jump to point=a */
	/* Execution resumes here when the scheduler first picks this
	 * coroutine (jump from point=b).  */
	cr_restore_interrupts(saved);
}

/* Core context switch: pick the next runnable coroutine and switch
 * to it.  The caller must have already disabled interrupts and set
 * this coroutine's next state (CR_RUNNABLE + enqueued, or
 * CR_PAUSED).  */
static inline void _cr_yield() {
	cid_t next;
	while ( !((next = coroutine_ringbuf_pop())) ) {
		/* No coroutines are runnable, wait for an interrupt
		 * to change that.  */
		cr_plat_wait_for_interrupt();
	}

	/* Switching to ourselves is a no-op; just restore state.  */
	if (next == coroutine_running) {
		coroutine_table[coroutine_running-1].state = CR_RUNNING;
		return;
	}

	if (!cr_setjmp(&coroutine_table[coroutine_running-1].env)) { /* point=c2 */
		coroutine_running = next;
		coroutine_table[coroutine_running-1].state = CR_RUNNING;
		cr_longjmp(&coroutine_table[coroutine_running-1].env); /* jump to point=c */
	}
	/* Execution resumes here when somebody else later switches
	 * back to us.  */
}

/* Voluntarily give other coroutines a chance to run; the calling
 * coroutine stays runnable and will be scheduled again.  */
void cr_yield(void) {
	debugf("cid=%zu: cr_yield()", coroutine_running);
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);

	bool saved = cr_save_and_disable_interrupts();
	/* Re-enqueue ourselves *before* switching away.  */
	coroutine_table[coroutine_running-1].state = CR_RUNNABLE;
	coroutine_ringbuf_push(coroutine_running);
	_cr_yield();
	cr_restore_interrupts(saved);
}

/* Stop running until another coroutine or an interrupt handler calls
 * cr_unpause()/cr_unpause_from_intrhandler() on us.  */
void cr_pause_and_yield(void) {
	debugf("cid=%zu: cr_pause_and_yield()", coroutine_running);
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);

	bool saved = cr_save_and_disable_interrupts();
	/* CR_PAUSED and *not* enqueued: we will not be scheduled
	 * again until someone unpauses us.  */
	coroutine_table[coroutine_running-1].state = CR_PAUSED;
	_cr_yield();
	cr_restore_interrupts(saved);
}

/* Terminate the calling coroutine: mark its slot dead and jump back
 * into coroutine_main() (point=b), which frees the stack and clears
 * the slot.  Coroutines must call this instead of returning.  */
[[noreturn]] void cr_exit(void) {
	debugf("cid=%zu: cr_exit()", coroutine_running);
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);

	/* The saved value is discarded: this coroutine never runs
	 * again, so there is nothing to restore it in.  */
	(void)cr_save_and_disable_interrupts();
	coroutine_table[coroutine_running-1].state = CR_NONE;
	cr_longjmp(&coroutine_main_env); /* jump to point=b */
}

/* Move a paused coroutine back onto the run queue.  The caller is
 * responsible for having interrupts disabled.  */
static void _cr_unpause(cid_t cid) {
	assert_cid_state(cid, state == CR_PAUSED);

	struct coroutine *cr = &coroutine_table[cid-1];
	cr->state = CR_RUNNABLE;
	coroutine_ringbuf_push(cid);
}

/* Make the paused coroutine `cid` runnable again.  To be called from
 * another coroutine; interrupt handlers must use
 * cr_unpause_from_intrhandler() instead.  */
void cr_unpause(cid_t cid) {
	debugf("cr_unpause(%zu)", cid);
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);

	/* The run queue is shared with interrupt handlers, so shut
	 * interrupts off around the enqueue.  */
	bool were_enabled = cr_save_and_disable_interrupts();
	_cr_unpause(cid);
	cr_restore_interrupts(were_enabled);
}

/* Like cr_unpause(), but for use from an interrupt handler, where
 * interrupts are already implicitly disabled.  */
void cr_unpause_from_intrhandler(cid_t cid) {
	debugf("cr_unpause_from_intrhandler(%zu)", cid);
	assert(cr_plat_is_in_intrhandler());

	_cr_unpause(cid);
}

/* Return the ID of the calling coroutine.  Must be called from
 * coroutine context, not from an interrupt handler.  */
cid_t cr_getcid(void) {
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);
	cid_t self = coroutine_running;
	return self;
}

/* Assert that the caller is running in coroutine context.  */
void cr_assert_in_coroutine(void) {
	assert(!cr_plat_is_in_intrhandler());
	assert_cid_state(coroutine_running, state == CR_RUNNING);
}

/* Assert that the caller is running in an interrupt handler.  */
void cr_assert_in_intrhandler(void) {
	assert(cr_plat_is_in_intrhandler());
}

/* cr_cid_info() **************************************************************/

#if CONFIG_COROUTINE_MEASURE_STACK

/* Report stack statistics for coroutine `cid` into `*ret`:
 *  - stack_cap: usable capacity (excluding the two guard regions)
 *  - stack_max: high-water mark of bytes ever written
 *  - stack_cur: bytes currently in use
 */
void cr_cid_info(cid_t cid, struct cr_cid_info *ret) {
	assert_cid(cid);
	assert(ret);

	/* stack_cap */
	ret->stack_cap = coroutine_table[cid-1].stack_size - 2*CR_STACK_GUARD_SIZE;

	/* stack_max: scan from the end the stack grows *away* from
	 * toward the end it grows *toward*, for the first byte whose
	 * fill pattern has been clobbered.  */
	ret->stack_max = ret->stack_cap;
	uint8_t *stack = (uint8_t *)coroutine_table[cid-1].stack;
	for (;;) {
		size_t i =
#if CR_PLAT_STACK_GROWS_DOWNWARD
			CR_STACK_GUARD_SIZE + ret->stack_cap - ret->stack_max
#else
			/* Fix: the usable region starts *after* the low
			 * guard, so the guard size must be added, not
			 * subtracted.  */
			CR_STACK_GUARD_SIZE + ret->stack_max - 1
#endif
			;
		if (ret->stack_max == 0 ||
		    stack[i] != stack_pattern[i%sizeof(stack_pattern)])
			break;
		ret->stack_max--;
	}

	/* stack_cur */
	uintptr_t sp;
	if (cid == coroutine_running)
		sp = cr_plat_get_sp();
	else if (coroutine_table[cid-1].state == CR_RUNNING)
		/* Per the invariants: a CR_RUNNING coroutine that is
		 * not coroutine_running is parked inside
		 * coroutine_add(), whose sp was saved at point=a.  */
		sp = cr_plat_setjmp_get_sp(&coroutine_add_env);
	else
		sp = cr_plat_setjmp_get_sp(&coroutine_table[cid-1].env);
	assert(sp);
	uintptr_t sb = (uintptr_t)coroutine_table[cid-1].stack;
#if CR_PLAT_STACK_GROWS_DOWNWARD
	/* Fix: a downward stack starts at the *top* of the allocation
	 * (sb + stack_size - guard); the original expression omitted
	 * `+ stack_size`, producing a wrapped-around value.  */
	ret->stack_cur = ((sb + coroutine_table[cid-1].stack_size)
	                  - CR_STACK_GUARD_SIZE) - sp;
#else
	ret->stack_cur = sp - (sb + CR_STACK_GUARD_SIZE);
#endif
}

#endif /* CONFIG_COROUTINE_MEASURE_STACK */