/* coroutine.c - Simple embeddable coroutine implementation
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-Licence-Identifier: AGPL-3.0-or-later
 */

#include <stdint.h> /* for uint8_t */
#include <stdio.h>  /* for printf(), fprintf(), stderr */
#include <stdlib.h> /* for malloc(), free() */
#include <assert.h>
#include <setjmp.h>

#include "coroutine.h"

/* Configuration **************************************************************/

/* Number of slots in coroutine_table, i.e. the maximum number of
 * coroutines that can exist at the same time.  */
#define COROUTINE_NUM           5
/* If non-zero, coroutine_main() prints how many bytes of its stack a
 * coroutine wrote to when that coroutine exits.  */
#define COROUTINE_MEASURE_STACK 1
/* If non-zero, a few bytes at each end of every coroutine stack are
 * reserved as a guard pattern that assert_cid_state() verifies.  */
#define COROUTINE_PROTECT_STACK 1
/* If non-zero, debugf() prints trace output; otherwise debugf()
 * expands to nothing.  */
#define COROUTINE_DEBUG         0

/* Implementation *************************************************************/

/*
 * Portability notes:
 *
 * - It uses GCC `__attribute__`s, though these can likely be easily
 *   swapped for other compilers.
 *
 * - It has a small bit of CPU-specific code (assembly, and definition
 *   of a STACK_GROWS_DOWNWARD={0,1} macro) in
 *   coroutine.c:call_with_stack().  Other than this, it should be
 *   portable to other CPUs.  It currently contains implementations
 *   for __x86_64__ and __arm__, and should be fairly easy to add
 *   implementations for other CPUs.
 *
 * - It uses setjmp()/longjmp() in "unsafe" ways.  POSIX-2017
 *   longjmp(3) says
 *
 *   > If the most recent invocation of setjmp() with the
 *   > corresponding jmp_buf ... or if the function containing the
 *   > invocation of setjmp() has terminated execution in the interim,
 *   > or if the invocation of setjmp() was within the scope of an
 *   > identifier with variably modified type and execution has left
 *   > that scope in the interim, the behavior is undefined.
 *
 *   We use longjmp() in both of these scenarios, but make it OK by using
 *   call_with_stack() to manage the stack ourselves, assuming the
 *   sole reason that longjmp() behavior is undefined in such cases is
 *   because the stack that its saved stack-pointer points to is no
 *   longer around.  It seems absurd that an implementation would
 *   choose to do something else, but I'm calling it out here because
 *   you never know.
 *
 *   Note that setjmp()/longjmp() are defined in 3 places: in the libc
 *   (glibc/newlib), as GCC intrinsics, and the lower-level GCC
 *   __builtin_{setjmp,longjmp} which the libc and intrinsic versions
 *   likely use.  Our assumptions seem to be valid for
 *   x86_64-pc-linux-gnu/gcc-14.2.1/glibc-2.40 and
 *   arm-none-eabi/gcc-14.1.0/newlib-4.4.0.
 *
 * Why not use <ucontext.h>, the now-deprecated (was in POSIX.1-2001,
 * is gone in POSIX-2008) predecessor to <setjmp.h>?  It would let us
 * do this without any assembly or unsafe assumptions.  Simply:
 * because newlib does not provide it.
 */

/*
 * Design decisions and notes:
 *
 * - Coroutines are launched with a stack that is filled with
 *   known values (zero or a pre-determined pattern, depending on the
 *   config above).  Because all stack variables should be initialized
 *   before they are read, this "shouldn't" make a difference, but:
 *   (1) Initializing it to a known value allows us to measure how
 *   much of the stack was written to, which is helpful to tune stack
 *   sizes.  (2) Leaving it uninitialized just gives me the willies.
 *
 * - Because embedded programs should be averse to using the heap,
 *   COROUTINE_NUM is fixed, instead of having coroutine_add()
 *   dynamically grow the coroutine_table as-needed.
 *
 * - On the flip-side, coroutine stacks are allocated on the heap
 *   instead of having them be statically-allocated along with
 *   coroutine_table.  (1) This reduces the blast-area of damage for a
 *   stack-overflow; and indeed if the end of the stack aligns with a
 *   page-boundary then memory-protection can even detect the overflow
 *   for us.  (2) Having different-looking addresses for stack-area vs
 *   static-area is handy for making things jump out at you when
 *   debugging.  (3) This can likely also improve things with being
 *   page-aligned.
 *
 * - Coroutines must use cr_exit() instead of returning because if
 *   they return then they will return to call_with_stack() in
 *   coroutine_add() (not to after the longjmp() call in
 *   coroutine_main()), and besides being
 *   wrong-for-our-desired-flow-control, that's a stack location that
 *   no longer exists.
 *
 * Things to consider changing:
 *
 * - Consider having _cr_transition() go ahead and find the next
 *   coroutine to run and longjmp() directly to it, instead of first
 *   jumping back to coroutine_main().  This could save a few cycles
 *   and a few bytes.
 */

/* debugf(fmt, ...): printf()-style trace output, prefixed with "dbg: ".
 * The format must be a string literal, because it is concatenated
 * with the prefix at compile time.  When COROUTINE_DEBUG is 0 the
 * whole call (including its arguments) expands to nothing, so the
 * arguments must not have side effects that the program relies on.  */
#if COROUTINE_DEBUG
#	define debugf(...) printf("dbg: " __VA_ARGS__)
#else
#	define debugf(...)
#endif

/* Saved by coroutine_add_with_stack_size() at point=a; cr_begin()
 * jumps back to it once the newly created coroutine has reached its
 * cr_begin() call.  */
static jmp_buf coroutine_add_env;
/* Saved by coroutine_main() at point=b; the running coroutine jumps
 * back to it when it yields, pauses, or exits.  */
static jmp_buf coroutine_main_env;

/* State of one slot in coroutine_table.  CR_NONE must be 0, because
 * empty slots are produced by zero-initialization (both the static
 * initializer of coroutine_table and the `(struct coroutine){0}`
 * reset in coroutine_main()).  */
enum coroutine_state {
	CR_NONE = 0,     /* this slot in the table is empty */
	CR_INITIALIZING, /* running, before cr_begin() */
	CR_RUNNING,      /* running, after cr_begin() */
	CR_RUNNABLE,     /* not running, but runnable; coroutine_main() will resume it */
	CR_PAUSED,       /* not running, and not runnable until cr_unpause() */
};

/*
 * Invariants (and non-invariants):
 *
 * - exactly 0 or 1 coroutines have state CR_INITIALIZING
 * - exactly 0 or 1 coroutines have state CR_RUNNING
 * - if coroutine_running is not zero, then
 *   coroutine_table[coroutine_running-1] is the currently-running
 *   coroutine
 * - the coroutine_running coroutine either has state CR_RUNNING or
 *   CR_INITIALIZING
 * - a coroutine having state CR_RUNNING does *not* imply that
 *   coroutine_running points at that coroutine; if that coroutine is
 *   in the middle of coroutine_add(), then coroutine_running points at
 *   the CR_INITIALIZING child coroutine, while leaving the parent
 *   coroutine as CR_RUNNING.
 */

/* One slot of coroutine_table.  An all-zero value is an empty slot.  */
struct coroutine {
	/* What the coroutine in this slot is currently doing (CR_NONE
	 * if the slot is empty).  */
	enum coroutine_state state;
	/* Where to resume the coroutine: saved with setjmp() in
	 * cr_begin()/_cr_transition(), jumped to by coroutine_main().  */
	jmp_buf              env;
	/* Size in bytes of the allocation that `stack` points to.  */
	size_t               stack_size;
	/* Start (lowest address) of the malloc()ed stack memory; it
	 * is free()d by coroutine_main() after the coroutine exits.  */
	void                *stack;
};

/* All coroutine slots.  Coroutine ID `cid` (1-based) lives at index
 * `cid-1`; ID 0 is therefore never a valid coroutine.  */
static struct coroutine coroutine_table[COROUTINE_NUM] = {0};
/* ID of the coroutine that is currently executing (or being
 * initialized), or 0 if none is; see the invariants above.  */
static cid_t            coroutine_running              = 0;

/**
 * Call fn(args) with the CPU stack pointer switched to `stack`.
 *
 * @param stack  initial stack-pointer value for the call; the caller
 *               is responsible for choosing the correct end of the
 *               stack memory and for its alignment
 * @param fn     function to call on the new stack
 * @param args   passed to fn as its only argument
 *
 * The caller's stack pointer is parked in a single function-static
 * variable before the switch and reloaded from it if fn returns.
 *
 * As a side effect of being preprocessed, this function defines the
 * STACK_GROWS_DOWNWARD macro that the rest of the file uses.
 *
 * NOTE(review): only the argument register is listed as a clobbered
 * register; the remaining call-clobbered registers are not.  That is
 * tolerable only as long as fn never returns here (the only caller
 * in this file asserts that it does not).
 */
static void call_with_stack(void *stack, cr_fn_t fn, void *args) {
	static void *saved_sp = NULL;

	/* As part of sbc-harness, this only really needs to support
	 * ARM-32, but being able to run it on x86-64 is useful for
	 * debugging.  */
#if __x86_64__
#define STACK_GROWS_DOWNWARD 1
	__asm__ volatile ("movq %%rsp , %0\n\t"    /* saved_sp = sp */
	                  "movq %1    , %%rsp\n\t" /* sp = stack */
	                  "movq %3    , %%rdi\n\t" /* arg0 = args */
	                  "call *%2\n\t"           /* fn() */
	                  "movq %0    , %%rsp"     /* sp = saved_sp */
	                  : /* saved_sp is written by the asm, so it
	                     * must be an output operand; modifying an
	                     * input operand is not allowed.  */
	                    /* %0 */"=m"(saved_sp)
	                  : /* %1 */"r"(stack),
	                    /* %2 */"r"(fn),
	                    /* %3 */"r"(args)
	                  : "rdi",
	                    /* the called function may touch any
	                     * memory and changes the flags */
	                    "memory", "cc"
	                  );
#elif __arm__
#define STACK_GROWS_DOWNWARD 1
	/* str/ldr can only work with a "lo" register, which sp is
	 * not, so we use r0 as an intermediate because we're going to
	 * clobber it with args anyway.  */
	__asm__ volatile ("mov r0, sp\n\t" /* [saved_sp = sp */
	                  "str r0, %0\n\t" /* ] */
	                  "mov sp, %1\n\t" /* [sp = stack] */
	                  "mov r0, %3\n\t" /* [arg0 = args] */
	                  "blx %2\n\t"     /* [fn()] */
	                  "ldr r0, %0\n\t" /* [sp = saved_sp */
	                  "mov sp, r0"     /* ] */
	                  : /* written by the asm => output operand */
	                    /* %0 */"=m"(saved_sp)
	                  : /* %1 */"r"(stack),
	                    /* %2 */"r"(fn),
	                    /* %3 */"r"(args)
	                  : "r0",
	                    /* the called function may touch any
	                     * memory and changes the flags */
	                    "memory", "cc"
	                  );
#else
#  error unsupported architecture
#endif
}

#if COROUTINE_MEASURE_STACK || COROUTINE_PROTECT_STACK
/* Byte pattern that fresh coroutine stacks are filled with, and that
 * the stack guard/measurement code looks for.  We just need a
 * pattern that is unlikely to occur naturally; this is just a few
 * bytes that I read from /dev/random.  */
static const uint8_t stack_pattern[] = {0x1e, 0x15, 0x16, 0x0a, 0xcc, 0x52, 0x7e, 0xb7};
#endif

#if COROUTINE_PROTECT_STACK
/* Assert that coroutine `cid` has a stack, and that the first and
 * last sizeof(stack_pattern) bytes of that stack's memory still hold
 * the fill pattern, i.e. that nothing has written into the guard
 * areas at either end.  `cid` must already be a valid coroutine
 * ID.  */
void assert_stack_protection(cid_t cid) {
	const struct coroutine *slot  = &coroutine_table[cid-1];
	const uint8_t          *bytes = slot->stack;

	assert(slot->stack_size);
	assert(slot->stack);
	for (size_t lo = 0; lo < sizeof(stack_pattern); lo++) {
		/* `lo` walks up from the first byte, `hi` walks down
		 * from the last byte */
		size_t hi = slot->stack_size - lo - 1;
		assert(bytes[lo] == stack_pattern[lo]);
		assert(bytes[hi] == stack_pattern[hi % sizeof(stack_pattern)]);
	}
}
#else
/* Stack guards are disabled: checking them is a no-op.  */
#  define assert_stack_protection(cid) ((void)0)
#endif

/*
 * Assert that `cid` is a valid coroutine ID (1..COROUTINE_NUM), that
 * the state of its slot satisfies `opstate`, and that its stack
 * guards are intact (a no-op unless COROUTINE_PROTECT_STACK is set).
 *
 * `opstate` is an operator-plus-operand fragment that is pasted right
 * after the slot's state, for example `== CR_RUNNING` or
 * `!= CR_RUNNING`.  `cid` is evaluated several times, so it must not
 * have side effects.
 */
#define assert_cid_state(cid, opstate) do {                     \
		assert((cid) > 0);                              \
		assert((cid) <= COROUTINE_NUM);                 \
		assert(coroutine_table[(cid)-1].state opstate); \
		assert_stack_protection(cid);                   \
	} while (0)

/**
 * Create a new coroutine that runs fn(args) on its own stack, and run
 * it up to its cr_begin() call.
 *
 * @param stack_size  size in bytes of the stack to allocate; must be
 *                    non-zero
 * @param fn          coroutine body; must call cr_begin() rather than
 *                    return
 * @param args        passed to fn as its only argument
 * @return the ID (>= 1) of the new coroutine, or 0 if every slot in
 *         coroutine_table is in use or the stack could not be
 *         allocated
 *
 * May be called either from outside any coroutine or from a running
 * coroutine; in both cases coroutine_running is restored before
 * returning.
 */
cid_t coroutine_add_with_stack_size(size_t stack_size, cr_fn_t fn, void *args) {
	/* Alignment given to the initial stack pointer.  16 satisfies
	 * both supported ABIs: x86-64 SysV requires 16 at a call,
	 * 32-bit ARM AAPCS requires 8.  */
	enum { STACK_ALIGNMENT = 16 };

	static cid_t last_created = 0;
	cid_t parent = coroutine_running;

	if (parent)
		assert_cid_state(parent, == CR_RUNNING);
	assert(stack_size);
	assert(fn);
	debugf("coroutine_add_with_stack_size(%zu, %#p, %#p)...\n", stack_size, fn, args);

	/* Look for an empty slot, starting with the one after the
	 * most recently created coroutine and wrapping around.  */
	cid_t child = 0;
	{
		size_t idx_base = last_created;
		for (size_t idx_shift = 0; idx_shift < COROUTINE_NUM; idx_shift++) {
			cid_t candidate = ((idx_base + idx_shift) % COROUTINE_NUM) + 1;
			if (coroutine_table[candidate-1].state == CR_NONE) {
				child = candidate;
				break;
			}
		}
	}
	if (!child)
		return 0; /* table is full */
	debugf("...child=%zu\n", child);

	/* Allocate before touching any global state, so that a failed
	 * allocation leaves the table exactly as it was.  */
	uint8_t *stack = malloc(stack_size);
	if (!stack)
		return 0;
#if COROUTINE_MEASURE_STACK || COROUTINE_PROTECT_STACK
	for (size_t i = 0; i < stack_size; i++)
		stack[i] = stack_pattern[i%sizeof(stack_pattern)];
#endif

	last_created = child;
	coroutine_table[child-1].stack_size = stack_size;
	coroutine_table[child-1].stack = stack;

	coroutine_running = child;
	coroutine_table[child-1].state = CR_INITIALIZING;
	if (!setjmp(coroutine_add_env)) { /* point=a */
		/* Start at whichever end of the memory the stack
		 * grows away from.  */
		uint8_t *stack_base = stack + (STACK_GROWS_DOWNWARD ? stack_size : 0);
#if COROUTINE_PROTECT_STACK
		/* Step over the guard pattern at that end.  */
#  if STACK_GROWS_DOWNWARD
		stack_base -= sizeof(stack_pattern);
#  else
		stack_base += sizeof(stack_pattern);
#  endif
#endif
		/* Round further into the usable area so that the
		 * initial stack pointer is ABI-aligned no matter what
		 * stack_size or the guard size is.  */
#if STACK_GROWS_DOWNWARD
		stack_base -= (uintptr_t)stack_base % STACK_ALIGNMENT;
#else
		stack_base += (STACK_ALIGNMENT - (uintptr_t)stack_base % STACK_ALIGNMENT) % STACK_ALIGNMENT;
#endif
		debugf("...stack     =%#p\n", coroutine_table[child-1].stack);
		debugf("...stack_base=%#p\n", (void *)stack_base);
		/* run until cr_begin() */
		call_with_stack(stack_base, fn, args);
		assert(false); /* should cr_begin() instead of returning */
	}
	/* We get here via the longjmp() in cr_begin(), which has
	 * already marked the child as runnable.  */
	assert_cid_state(child, == CR_RUNNABLE);
	if (parent)
		assert_cid_state(parent, == CR_RUNNING);
	coroutine_running = parent;

	return child;
}

#if COROUTINE_MEASURE_STACK
/* Return how many bytes of cr's usable stack area (the allocation
 * minus the guard areas, if enabled) appear to have been written to.
 * This scans from the far end of the stack towards its base and
 * counts every byte up to the first one that no longer holds the
 * fill pattern as unused.  */
static size_t measure_stack_used(const struct coroutine *cr) {
	const size_t   guard  = COROUTINE_PROTECT_STACK ? sizeof(stack_pattern) : 0;
	const size_t   usable = cr->stack_size - 2*guard;
	const uint8_t *bytes  = cr->stack;

	size_t used = usable;
	while (used > 0) {
		/* Index of the farthest-from-base byte that is still
		 * counted as used.  In both directions the usable
		 * area starts `guard` bytes into the allocation.  */
		size_t i = STACK_GROWS_DOWNWARD
			? guard + usable - used
			: guard + used - 1;
		if (bytes[i] != stack_pattern[i%sizeof(stack_pattern)])
			break;
		used--;
	}
	return used;
}
#endif

/**
 * Scheduler loop: repeatedly walk coroutine_table in order and resume
 * every coroutine that is CR_RUNNABLE, until a complete pass over the
 * table finds nothing to run.
 *
 * A coroutine that has exited (its state is CR_NONE when control
 * comes back) has its stack freed and its slot cleared.  When no
 * coroutine is runnable an error is printed to stderr and the
 * function returns with coroutine_running reset to 0.
 */
void coroutine_main(void) {
	debugf("coroutine_main()\n");
	bool ran = false; /* did anything run during the current pass? */
	for (coroutine_running = 1;; coroutine_running = (coroutine_running%COROUTINE_NUM)+1) {
		if (coroutine_running == 1)
			ran = false; /* a new pass starts */
		struct coroutine *cr = &coroutine_table[coroutine_running-1];
		if (cr->state == CR_RUNNABLE) {
			debugf("running cid=%zu...\n", coroutine_running);
			ran = true;
			cr->state = CR_RUNNING;
			if (!setjmp(coroutine_main_env)) { /* point=b */
				longjmp(cr->env, 1); /* jump to point=c */
				assert(false); /* should cr_exit() instead of returning */
			}
			/* The coroutine yielded, paused, or exited.  */
			assert_cid_state(coroutine_running, != CR_RUNNING);
			if (cr->state == CR_NONE) {
#if COROUTINE_MEASURE_STACK
				printf("info: cid=%zu: exited having used %zu B stack space\n", coroutine_running, measure_stack_used(cr));
#endif
				free(cr->stack);
				coroutine_table[coroutine_running-1] = (struct coroutine){0};
			}
		}
		if (coroutine_running == COROUTINE_NUM && !ran) {
			fprintf(stderr, "error: no runnable coroutines\n");
			/* Nothing is running any more; don't leave
			 * coroutine_running pointing at an idle slot.  */
			coroutine_running = 0;
			return;
		}
	}
}

/**
 * Called by a coroutine once it has finished its initialization:
 * marks it CR_RUNNABLE, records the point to resume from, and hands
 * control back to coroutine_add_with_stack_size().
 *
 * The call only returns later, when coroutine_main() resumes the
 * coroutine for the first time.
 *
 * @return always true.  NOTE(review): the original fell off the end
 *         of this non-void function; true is chosen as the value for
 *         the only path that returns -- confirm against the contract
 *         documented in coroutine.h.
 */
bool cr_begin(void) {
	assert_cid_state(coroutine_running, == CR_INITIALIZING);

	coroutine_table[coroutine_running-1].state = CR_RUNNABLE;
	if (!setjmp(coroutine_table[coroutine_running-1].env)) /* point=c1 */
		longjmp(coroutine_add_env, 1); /* jump to point=a */
	/* resumed by coroutine_main() */
	return true;
}

static inline void _cr_transition(enum coroutine_state state) {
	assert_cid_state(coroutine_running, == CR_RUNNING);
	debugf("cid=%zu: transition %i->%i\n", coroutine_running, coroutine_table[coroutine_running-1].state, state);

	coroutine_table[coroutine_running-1].state = state;
	if (!setjmp(coroutine_table[coroutine_running-1].env)) /* point=c2 */
		longjmp(coroutine_main_env, 1); /* jump to point=b */
}

/* Give up the CPU but stay runnable: coroutine_main() will resume
 * this coroutine on a later pass.  */
void cr_yield(void) {
	_cr_transition(CR_RUNNABLE);
}
/* Give up the CPU and become non-runnable: this coroutine is not
 * resumed until it is made runnable again with cr_unpause().  */
void cr_pause_and_yield(void) {
	_cr_transition(CR_PAUSED);
}

void cr_exit(void) {
	assert_cid_state(coroutine_running, == CR_RUNNING);
	debugf("cid=%zu: exit\n", coroutine_running);

	coroutine_table[coroutine_running-1].state = CR_NONE;
	longjmp(coroutine_main_env, 1); /* jump to point=b */
}

/* Make the paused coroutine `cid` runnable again.  `cid` must refer
 * to a coroutine that is currently CR_PAUSED.  This only changes its
 * state; it does not switch to it.  */
void cr_unpause(cid_t cid) {
	assert_cid_state(cid, == CR_PAUSED);
	debugf("cr_unpause(%zu)\n", cid);

	struct coroutine *target = &coroutine_table[cid-1];
	target->state = CR_RUNNABLE;
}

/* Return the current value of coroutine_running: the ID of the
 * coroutine that is executing, or 0 if none is.  */
cid_t cr_getcid(void) {
	const cid_t current = coroutine_running;
	return current;
}