summaryrefslogtreecommitdiff
path: root/coroutine.c
blob: 65971d42fdb786e5e2a66bdb896cae7a61518b05 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/* coroutine.c - Simple coroutine and request/response implementation
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <stdint.h> /* for uint8_t */
#include <string.h> /* for memset() */
#include <stdio.h>  /* for fprintf(), stderr */
#include <assert.h>
#include <setjmp.h>

#include "cfg_limits.h"
#include "coroutine.h"

/* Lifecycle state of one slot in coroutine_table. */
enum coroutine_state {
	/* Slot has never been used; this is what the zero-initialized
	 * table starts out as.  Free for coroutine_add() to claim. */
	CR_ZERO         = 0,
	/* Coroutine has called cr_exit(); coroutine_add() wipes the slot
	 * before reusing it. */
	CR_GARBAGE      = 1,
	/* Set by coroutine_add() while the coroutine function is being
	 * called for the first time; _cr_begin() asserts this state. */
	CR_INITIALIZING = 2,
	/* The coroutine that coroutine_main() has switched into. */
	CR_RUNNING      = 3,
	/* Eligible to be picked by coroutine_main(). */
	CR_RUNNABLE     = 4,
	/* Suspended by cr_pause_and_yield() until cr_unpause(). */
	CR_PAUSED       = 5,
};

/* Bookkeeping for one coroutine; coroutine_table holds COROUTINE_NUM of
 * these, and the coroutine with id `cid` lives at index `cid-1`. */
struct coroutine {
	/* Where this slot is in its lifecycle. */
	enum coroutine_state state;
	/* Saved context that coroutine_main() longjmp()s to in order to
	 * resume the coroutine; written by _cr_begin() and by the yield
	 * functions. */
	jmp_buf              env;
	/* The coroutine's private stack.
	 * NOTE(review): no explicit alignment is requested, yet
	 * coroutine_add() hands an end of this array to call_with_stack()
	 * as the initial stack pointer -- confirm that address satisfies
	 * the target ABI's stack alignment. */
	uint8_t              stack[COROUTINE_STACK_SIZE];
};

/* Every coroutine slot; the coroutine with id `cid` is coroutine_table[cid-1]. */
static struct coroutine coroutine_table[COROUTINE_NUM] = {0};

/* Id of the coroutine currently being created or run.  Starts out as 0,
 * which coroutine_add() accepts as "no parent coroutine". */
static cid_t coroutine_running = 0;

/* Context that coroutines longjmp() to in order to hand control back to
 * coroutine_add() or coroutine_main(), whichever set it last. */
static jmp_buf scheduler_env;

/* coroutine_init - currently a no-op; the file-scope state is already
 * statically zero-initialized. */
void coroutine_init(void) {
	/* Nothing to do. */
}

/* Scratch variable that the SP macro stores the sampled stack pointer in. */
void *__sp;

/* SP - evaluates to the current value of %rsp as a `void *`.
 *
 * Debugging aid.  x86-64 only (reads %rsp directly) and relies on GNU
 * statement expressions; the result is also left in __sp. */
#define SP                                         \
	({                                         \
		asm("movq %%rsp,%0" : "=r"(__sp)); \
		__sp;                              \
	})

/* print_addr - debugging aid: print an address and where it lies relative
 * to coroutine_table.
 *
 * name: label to print in front of the address
 * addr: the address to classify
 *
 * Writes one line to stdout.  If `addr` is inside coroutine_table the line
 * names the slot containing it and that slot's [start, end) bounds;
 * otherwise it prints the bounds of the whole table.
 */
static void print_addr(const char *name, void *addr) {
	/* Do the range math on byte pointers; arithmetic on `void *` is a
	 * GNU extension, not standard C. */
	uint8_t *table_lo = (uint8_t *)&coroutine_table;
	uint8_t *table_hi = table_lo + sizeof(coroutine_table);
	uint8_t *p        = addr;

	if (p < table_lo || p >= table_hi) {
		printf("%s=%p (outside coroutine_table [%p, %p])\n", name, addr,
		       (void *)table_lo, (void *)table_hi);
		return;
	}

	/* The slots tile the table exactly, so the containing slot can be
	 * computed directly from the byte offset. */
	size_t   i       = (size_t)(p - table_lo) / sizeof(coroutine_table[0]);
	uint8_t *slot_lo = (uint8_t *)&coroutine_table[i];
	printf("%s=%p (in coroutine_table[%zu] [%p, %p])\n", name, addr, i,
	       (void *)slot_lo, (void *)(slot_lo + sizeof(coroutine_table[i])));
}

/* call_with_stack - call fn(args) with the stack pointer switched to `stack`.
 *
 * stack: value to load into the stack pointer for the duration of the
 *        call (coroutine_add() passes the high end of the coroutine's
 *        stack array when STACK_GROWS_DOWNWARD is true)
 * fn:    function to call
 * args:  passed to fn as its first argument
 *
 * Returns, back on the original stack, once fn returns.
 *
 * The #if chain below also defines STACK_GROWS_DOWNWARD, which
 * coroutine_add() uses.
 */
static void call_with_stack(void *stack, cr_fn_t fn, void *args) {
	/* Kept static so that the asm can address it without going through
	 * the stack pointer, which is pointing at the other stack while fn
	 * runs. */
	static void  *scheduler_sp = NULL;

	/* This only really exists for running on ARM-32, but being
	 * able to run it on x86-64 is useful for debugging.  */
#if __x86_64__
#define STACK_GROWS_DOWNWARD 1
	print_addr("sp", SP);
	asm volatile ("movq %%rsp , %0\n\t"    /* scheduler_sp = sp */
	              "movq %1    , %%rsp\n\t" /* sp = stack */
	              "movq %3    , %%rdi\n\t" /* arg0 = args */
	              "call *%2\n\t"           /* fn() */
	              "movq %0    , %%rsp"     /* sp = scheduler_sp */
	              : /* scheduler_sp is written, so it must be an output. */
	                /* %0 */"=m"(scheduler_sp)
	              : /* %1 */"r"(stack),
	                /* %2 */"r"(fn),
	                /* %3 */"r"(args)
	              : /* %rdi is overwritten explicitly, and the call may
	                 * trash every caller-saved register and any memory;
	                 * listing them also keeps the "r" inputs out of
	                 * registers that would be overwritten before use. */
	                "rax", "rcx", "rdx", "rsi", "rdi",
	                "r8", "r9", "r10", "r11",
	                "xmm0", "xmm1", "xmm2", "xmm3",
	                "xmm4", "xmm5", "xmm6", "xmm7",
	                "xmm8", "xmm9", "xmm10", "xmm11",
	                "xmm12", "xmm13", "xmm14", "xmm15",
	                "cc", "memory"
	              );
#elif __arm__
#define STACK_GROWS_DOWNWARD 1
	/* FIXME: this is placeholder pseudo-code, not valid ARM assembly
	 * (the instructions are not separated and the operands are C
	 * expressions); it still has to be written for real. */
	asm volatile ("mov r0, coroutine_table[_cur_cid].arg"
	              "mov _saved_stack sp"
	              "mov sp, coroutine_table[_cur_cid].stack"
	              "bl  coroutine_table[_cur_cid].fn"
	              "mov _saved_stack sp");
#else
#error unsupported architecture
#endif
}

/* coroutine_add - create a coroutine that runs fn(args) on its own stack.
 *
 * fn:   the coroutine's entry point
 * args: passed to fn as its argument
 *
 * Picks a free slot (CR_ZERO or CR_GARBAGE), searching round-robin
 * starting just after the most recently created coroutine, marks it
 * CR_INITIALIZING and calls fn on that slot's stack.  Control comes back
 * here either when fn returns (the slot is then marked CR_RUNNABLE) or
 * through a longjmp() to scheduler_env (the slot must by then already be
 * CR_RUNNABLE).
 *
 * Returns the new coroutine's id (>= 1), or 0 if every slot is in use.
 */
cid_t coroutine_add(cr_fn_t fn, void *args) {
	static cid_t last_created = 0;

	/* Find a free slot.  Written without a goto: a label directly
	 * before a closing brace is not valid C before C23. */
	cid_t  child    = 0;
	size_t idx_base = last_created;
	for (size_t idx_shift = 0; idx_shift < COROUTINE_NUM; idx_shift++) {
		cid_t candidate = ((idx_base + idx_shift) % COROUTINE_NUM) + 1;
		enum coroutine_state candidate_state = coroutine_table[candidate-1].state;
		if (candidate_state == CR_ZERO || candidate_state == CR_GARBAGE) {
			child = candidate;
			break;
		}
	}
	if (child == 0)
		return 0;

	/* A slot left behind by an exited coroutine is wiped before reuse. */
	if (coroutine_table[child-1].state == CR_GARBAGE)
		memset(&coroutine_table[child-1], 0, sizeof(struct coroutine));
	last_created = child;

	/* Make the child the "running" coroutine while it initializes, and
	 * remember who to switch back to afterwards. */
	cid_t parent = coroutine_running;
	assert(parent == 0 || coroutine_table[parent-1].state == CR_RUNNING);
	coroutine_running = child;
	coroutine_table[child-1].state = CR_INITIALIZING;
	/* Debugging output. */
	if (child == 2) {
		print_addr("&child", &child);
	}
	if (!setjmp(scheduler_env)) {
		call_with_stack(coroutine_table[child-1].stack + (STACK_GROWS_DOWNWARD ? COROUTINE_STACK_SIZE : 0), fn, args);
		assert(coroutine_table[child-1].state == CR_INITIALIZING);
		coroutine_table[child-1].state = CR_RUNNABLE;
	}
	assert(coroutine_table[child-1].state == CR_RUNNABLE);
	coroutine_running = parent;

	return child;
}

/* coroutine_main - the scheduler loop.
 *
 * Cycles through coroutine ids 1..COROUTINE_NUM forever.  Each CR_RUNNABLE
 * coroutine found is marked CR_RUNNING and resumed by longjmp()ing to its
 * saved env; control returns to the loop when that coroutine longjmp()s
 * to scheduler_env (see the yield/exit functions).
 *
 * Returns, after printing an error to stderr, once a full pass over the
 * table finds nothing runnable.
 */
void coroutine_main(void) {
	/* Whether any coroutine was resumed during the current pass. */
	bool ran = false;
	for (coroutine_running = 1;; coroutine_running = (coroutine_running%COROUTINE_NUM)+1) {
		/* id 1 marks the start of a new pass over the table. */
		if (coroutine_running == 1)
			ran = false;
		if (coroutine_table[coroutine_running-1].state == CR_RUNNABLE) {
			ran = true;
			coroutine_table[coroutine_running-1].state = CR_RUNNING;
			/* longjmp() never returns; execution continues past
			 * this `if` when the coroutine jumps back to
			 * scheduler_env, having already updated its own
			 * state. */
			if (!setjmp(scheduler_env))
				longjmp(coroutine_table[coroutine_running-1].env, 1);
		}
		/* coroutine_running only takes the values 1..COROUTINE_NUM,
		 * so the end of a pass is COROUTINE_NUM itself. */
		if (coroutine_running == COROUTINE_NUM && !ran) {
			fprintf(stderr, "error: no runnable coroutines\n");
			return;
		}
	}
}

/* _cr_begin - record the point at which the coroutine being initialized
 * will first be resumed.
 *
 * Saves the calling context into the running coroutine's env.  That
 * coroutine must be CR_INITIALIZING (asserted).
 *
 * Returns false when called directly; returns true when control arrives
 * here again through a longjmp() to that env (coroutine_main() passes 1).
 *
 * NOTE(review): the context is captured inside this function, which then
 * returns; ISO C leaves a later longjmp() to such a context undefined --
 * confirm that the intended toolchains/targets tolerate it.
 */
bool _cr_begin(void) {
	assert(coroutine_table[coroutine_running-1].state == CR_INITIALIZING);
	return setjmp(coroutine_table[coroutine_running-1].env);
}

/* _cr_transition - put the running coroutine into `state` and hand
 * control back to whoever set scheduler_env.
 *
 * state: the state to record for the running coroutine, which must
 *        currently be CR_RUNNING (asserted)
 *
 * For CR_GARBAGE no context is saved and the longjmp() is taken
 * unconditionally, so this function does not return.  For any other
 * state the current context is first saved in the coroutine's env; the
 * function returns to its caller when something later longjmp()s to
 * that env.
 */
static inline void _cr_transition(enum coroutine_state state) {
	assert(coroutine_table[coroutine_running-1].state == CR_RUNNING);
	coroutine_table[coroutine_running-1].state = state;
	/* setjmp() yields 0 when saving (-> jump to the scheduler) and
	 * non-zero when being resumed (-> fall through and return). */
	if (state == CR_GARBAGE || !setjmp(coroutine_table[coroutine_running-1].env))
		longjmp(scheduler_env, 1);
}

/* cr_exit - end the running coroutine: its slot becomes CR_GARBAGE and
 * control jumps back to the scheduler.  Does not return to the caller. */
void cr_exit(void) {
	_cr_transition(CR_GARBAGE);
}
/* cr_yield - give up the CPU but stay CR_RUNNABLE; returns once
 * the scheduler resumes this coroutine. */
void cr_yield(void) {
	_cr_transition(CR_RUNNABLE);
}
/* cr_pause_and_yield - give up the CPU and become CR_PAUSED.  The
 * scheduler only resumes CR_RUNNABLE coroutines, so this returns only
 * after cr_unpause() has been called for this coroutine. */
void cr_pause_and_yield(void) {
	_cr_transition(CR_PAUSED);
}

/* cr_unpause - make a paused coroutine eligible to run again.
 *
 * cid: id of the coroutine to wake; must be in 1..COROUTINE_NUM and that
 *      coroutine must be CR_PAUSED (both asserted)
 *
 * Only changes the state to CR_RUNNABLE; it does not switch to the
 * coroutine.
 */
void cr_unpause(cid_t cid) {
	/* Check the id before it is used as a table index: 0 or anything
	 * past COROUTINE_NUM would otherwise access out of bounds. */
	assert(cid > 0 && (size_t)cid <= (size_t)COROUTINE_NUM);
	assert(coroutine_table[cid-1].state == CR_PAUSED);
	coroutine_table[cid-1].state = CR_RUNNABLE;
}

/* cr_getcid - return the id currently stored in coroutine_running, i.e.
 * the coroutine being run or created (0 if there is none). */
cid_t cr_getcid(void) {
	cid_t self = coroutine_running;
	return self;
}