/* libpromise/promise.c - Promises for libcr * * Copyright (C) 2024 Luke T. Shumaker * SPDX-Licence-Identifier: AGPL-3.0-or-later */ #include /* for assert() */ #include /* for size_t */ #include #include /* Header *********************************************************************/ struct promise_vtable; struct promise_queue; struct __promise; typedef struct __promise { struct promise_vtable *vtable; /* queue */ struct promise_queue *queue_root; struct __promise *queue_front, *queue_rear; struct { cid_t cid; size_t cnt; struct __promise **vec; } waiter; bool resolved; } implements_promise; struct promise_vtable { void (*on_complete)(implements_promise *); bool (*would_block)(implements_promise *); }; struct promise_queue { implements_promise *front, *rear; }; /* consumer */ void promise_push_to_rear(implements_promise *, struct promise_queue *); void promise_await(implements_promise *); size_t promise_select_v(size_t arg_cnt, implements_promise *arg_vec[]); #define promise_select_l(...) ({ \ implements_promise *args[] = { __VA_ARGS__ }; \ promise_select_v(sizeof(args)/sizeof(args[0]), args); \ }) /* producer */ void promise_resolve(implements_promise *); /* Impl ***********************************************************************/ /* Linked-list operations ==========================================*/ void promise_push_to_rear(implements_promise *item, struct promise_queue *root) { assert(item); assert(root); item->root = root; item->front = root->rear; item->rear = NULL; if (root->rear) root->rear->rear = item; else root->front = item; root->rear = item; } static void promise_remove(implements_promise_queueitem *item) { assert(item); assert(item->root); struct promise_queue *root = item->root; if (item->front) item->front->rear = item->rear; else root->front = item->rear; if (item->rear) item->rear->front = item->front; else root->rear = item->front; } /* Resolve =========================================================*/ static inline void promise_remove_all(size_t arg_cnt, implements_promise *arg_vec[]) { for (size_t i = 0; i < arg_cnt; i++) { if (!arg_vec[i]) continue; promise_remove(arg_vec[i]); arg_vec[i]->waiter = (typeof(arg_vec[i]->waiter)){0}; } } void promise_resolve(implements_promise *arg) { assert(arg); bool enable = cr_disable_interrupts(); assert(!arg->resolved); VCALL(arg, on_complete); arg->resolved = true; if (arg->waiter.cid) { typeof(arg->waiter) waiter = arg->waiter; promise_resolve_all(waiter->cnt, waiter->vec); cr_unpause(waiter->cid); } if (enable) cr_enable_interrupts(); } /* Single-wait =====================================================*/ void promise_await(implements_promise *arg) { assert(arg); bool enable = cr_disable_interrupts(); if (arg->resolved) { /* do nothing */ } else if (!VCALL(arg, would_block)) { promise_resolve(arg); } else { implements_promise *vec[1] = { arg }; arg->waiter = (typeof(arg->waiter)){ .cid = cr_getcid(); .cnt = 1; .vec = vec; }; cr_pause_and_yield(); } if (enable) cr_enable_interrupts(); assert(arg->resolved); } /* Multi-wait ======================================================*/ static inline size_t pickone(size_t cnt) { long fair_cnt = (0x80000000L / cnt) * cnt; long rnd; do { rnd = random(); } while (rnd >= fair_cnt); return rnd % cnt; } size_t promise_select_v(size_t arg_cnt, implements_promise *arg_vec[]) { size_t cnt_blocking = 0; size_t cnt_nonblock = 0; size_t cnt_default = 0; size_t idx_default; bool wouldblock[arg_cnt]; assert(arg_cnt); assert(arg_vec); cr_disable_interrupts(); for (size_t i = 0; i < arg_cnt; i++) { if (arg_vec[i] == NULL) { cnt_default++; idx_default = i; } else if ( ((wouldblock[i] = (!arg_vec[i]->resolved) && VCALL(arg_vec[i], would_block))) ) cnt_blocking++; else cnt_nonblock++; } if (cnt_nonblock) { size_t choice = pickone(cnt_nonblock); for (size_t i = 0, seen = 0; i < arg_cnt; i++) { if (arg_vec[i] && !wouldblock[i]) { if (seen == choice) { promise_resolve(arg_vec[i]); cr_enable_interrupts(); return i; } seen++; } } __builtin_unreachable(); } if (cnt_default) { return resolve_select(arg_cnt, arg_vec, idx_default); for (size_t i = 0, seen = 0; i < arg_cnt; i++) { arg_vec[i]->cid = cr_getcid(); arg_vec[i]->cid = cr_getcid(); } cr_pause_and_yield(); for }