/* libhw/host_alarmclock.c - <libhw/generic/alarmclock.h> implementation for POSIX hosts
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <errno.h>
#include <error.h>
#include <signal.h>
#include <time.h>

#include <libcr/coroutine.h>
#include <libmisc/assert.h>
#include <libmisc/vcall.h>

#define IMPLEMENTATION_FOR_LIBHW_GENERIC_ALARMCLOCK_H YES
#include <libhw/generic/alarmclock.h>

#include "host_util.h" /* for host_sigrt_alloc(), ns_to_host_ns_time() */

/* Types **********************************************************************/

/* A host (POSIX) clock exposed through the alarmclock interface.
 *
 * Pending triggers live in `queue`, a doubly-linked list (via each
 * trigger's prev/next) kept sorted by fire_at_ns, earliest first.  A
 * single POSIX timer is armed for the deadline of the head of that list.
 */
struct hostclock {
	implements_alarmclock;
	/* true once the signal handler is installed and timer_id is created */
	bool                             initialized;
	/* clock passed to clock_gettime() and timer_create() */
	clockid_t                        clock_id;
	/* only meaningful once `initialized` is true */
	timer_t                          timer_id;
	/* head (earliest deadline) of the pending-trigger list, or NULL */
	struct alarmclock_trigger       *queue;
};

/* Globals ********************************************************************/

static uint64_t hostclock_get_time_ns(implements_alarmclock *self);
static bool     hostclock_add_trigger(implements_alarmclock *self,
                                      struct alarmclock_trigger *trigger,
                                      uint64_t   fire_at_ns,
                                      void       (*cb)(void *),
                                      void      *cb_arg);
static void     hostclock_del_trigger(implements_alarmclock *self,
                                      struct alarmclock_trigger *trigger);

static struct alarmclock_vtable hostclock_vtable = {
	.get_time_ns = hostclock_get_time_ns,
	.add_trigger = hostclock_add_trigger,
	.del_trigger = hostclock_del_trigger,
};

static struct hostclock clock_monotonic = {
	.vtable   = &hostclock_vtable,
	.clock_id = CLOCK_MONOTONIC,
};

implements_alarmclock *bootclock = &clock_monotonic;

/* Main implementation ********************************************************/

/* Mark a function parameter as intentionally unused.
 *
 * As written, the macro body is only a comment, so UNUSED(x) expands to
 * nothing and the parameter is left unnamed in the function definition.
 * Unnamed parameters in a definition are standard C only as of C23.
 * NOTE(review): presumably the project builds as C23/gnu23 -- confirm
 * before relying on this in a pre-C23 build (the commented-out
 * `name __attribute__ ((unused))` form would be the portable spelling).
 */
#define UNUSED(name) /* name __attribute__ ((unused)) */

/*
 * alarmclock get_time_ns method: read this hostclock's clock_id with
 * clock_gettime() and convert the result to nanoseconds with
 * ns_from_host_ns_time().  If clock_gettime() fails, the process exits
 * via error(1, ...).
 */
static uint64_t hostclock_get_time_ns(implements_alarmclock *_alarmclock) {
	struct hostclock *alarmclock =
		VCALL_SELF(struct hostclock, implements_alarmclock, _alarmclock);
	assert(alarmclock);

	struct timespec now;
	int rc = clock_gettime(alarmclock->clock_id, &now);
	if (rc != 0)
		error(1, errno, "clock_gettime(%d)", (int)alarmclock->clock_id);

	return ns_from_host_ns_time(now);
}

/*
 * Signal handler installed (with SA_SIGINFO) by hostclock_add_trigger().
 *
 * The hostclock is recovered from the sigevent value carried in
 * info->si_value.  Every queued trigger whose deadline has passed is
 * removed from the queue and its callback is run.  The timer is then
 * re-armed (absolute, one-shot) for the new head of the queue, if any.
 */
static void hostclock_handle_sig_alarm(int UNUSED(sig), siginfo_t *info, void *UNUSED(ucontext)) {
	struct hostclock *alarmclock = info->si_value.sival_ptr;
	assert(alarmclock);

	while (alarmclock->queue &&
	       alarmclock->queue->fire_at_ns <= hostclock_get_time_ns(alarmclock)) {
		struct alarmclock_trigger *trigger = alarmclock->queue;

		/* Fully unlink the trigger *before* running its callback.  That
		 * way the callback can safely re-add this trigger (or add/delete
		 * others) without its new links being clobbered here.  The new
		 * head's back-link must also be cleared, otherwise a later
		 * del_trigger() on it would follow a stale `prev` and never
		 * update `queue`. */
		alarmclock->queue = trigger->next;
		if (alarmclock->queue)
			alarmclock->queue->prev = NULL;
		trigger->alarmclock = NULL;
		trigger->next = NULL;
		trigger->prev = NULL;

		trigger->cb(trigger->cb_arg);
	}

	if (alarmclock->queue) {
		struct itimerspec alarmspec = {
			.it_value    = ns_to_host_ns_time(alarmclock->queue->fire_at_ns),
			.it_interval = {0},
		};
		if (timer_settime(alarmclock->timer_id, TIMER_ABSTIME, &alarmspec, NULL) != 0)
			error(1, errno, "timer_settime");
	}
}

/*
 * alarmclock add_trigger method: schedule `cb(cb_arg)` to run, from this
 * clock's signal handler, once the clock reaches `fire_at_ns`.
 *
 * The first call on a given hostclock does the lazy setup:
 * - obtains a signal number from host_sigrt_alloc();
 * - installs hostclock_handle_sig_alarm() for that signal;
 * - creates the POSIX timer that delivers it.
 *
 * Whenever the new trigger becomes the head of the queue, the timer is
 * (re-)armed for its deadline.  Failure of sigaction(), timer_create(),
 * or timer_settime() exits the process via error(1, ...).
 *
 * Always returns false.
 */
static bool hostclock_add_trigger(implements_alarmclock *_alarmclock,
                                  struct alarmclock_trigger *trigger,
                                  uint64_t   fire_at_ns,
                                  void       (*cb)(void *),
                                  void      *cb_arg) {
	struct hostclock *alarmclock =
		VCALL_SELF(struct hostclock, implements_alarmclock, _alarmclock);
	assert(alarmclock);
	assert(trigger);
	assert(fire_at_ns);
	assert(cb);

	trigger->alarmclock = alarmclock;
	trigger->fire_at_ns = fire_at_ns;
	trigger->cb         = cb;
	trigger->cb_arg     = cb_arg;

	bool saved = cr_save_and_disable_interrupts();

	/* Insert into the queue, keeping it sorted by fire_at_ns.  Using
	 * ">=" places the new trigger after existing ones with an equal
	 * deadline.  The predecessor is tracked while walking so that the
	 * back-link is correct in every case.  This matters when appending
	 * to the tail of a non-empty queue: there *dst is NULL and cannot
	 * supply the predecessor. */
	struct alarmclock_trigger *prev = NULL;
	struct alarmclock_trigger **dst = &alarmclock->queue;
	while (*dst && fire_at_ns >= (*dst)->fire_at_ns) {
		prev = *dst;
		dst = &(*dst)->next;
	}
	trigger->prev = prev;
	trigger->next = *dst;
	if (*dst)
		(*dst)->prev = trigger;
	*dst = trigger;

	if (!alarmclock->initialized) {
		/* Timer expirations are delivered as a signal whose value
		 * points back at this hostclock; the handler reads it from
		 * si_value. */
		struct sigevent how_to_notify = {
			.sigev_notify = SIGEV_SIGNAL,
			.sigev_signo  = host_sigrt_alloc(),
			.sigev_value  = {
				.sival_ptr = alarmclock,
			},
		};
		struct sigaction action = {
			.sa_flags = SA_SIGINFO,
			.sa_sigaction = hostclock_handle_sig_alarm,
		};
		if (sigaction(how_to_notify.sigev_signo, &action, NULL) != 0)
			error(1, errno, "sigaction");
		if (timer_create(alarmclock->clock_id, &how_to_notify, &alarmclock->timer_id) != 0)
			error(1, errno, "timer_create(%d)", (int)alarmclock->clock_id);
		alarmclock->initialized = true;
	}

	/* The timer only needs re-arming when the earliest deadline changed. */
	if (alarmclock->queue == trigger) {
		struct itimerspec alarmspec = {
			.it_value    = ns_to_host_ns_time(trigger->fire_at_ns),
			.it_interval = {0},
		};
		if (timer_settime(alarmclock->timer_id, TIMER_ABSTIME, &alarmspec, NULL) != 0)
			error(1, errno, "timer_settime");
	}
	cr_restore_interrupts(saved);

	return false;
}

/*
 * alarmclock del_trigger method: unlink `trigger` from this hostclock's
 * queue if it is currently queued here (trigger->alarmclock matches).
 * Otherwise this is a no-op.  On removal, the trigger's alarmclock, prev,
 * and next fields are cleared.
 *
 * The timer is deliberately not re-armed.  If the removed trigger was the
 * head, the already-armed expiry simply finds nothing due, and the
 * handler re-arms for whatever is then at the head.
 */
static void hostclock_del_trigger(implements_alarmclock *_alarmclock,
                                  struct alarmclock_trigger *trigger) {
	struct hostclock *alarmclock =
		VCALL_SELF(struct hostclock, implements_alarmclock, _alarmclock);

	assert(alarmclock);
	assert(trigger);

	bool saved = cr_save_and_disable_interrupts();
	if (trigger->alarmclock == alarmclock) {
		struct alarmclock_trigger *before = trigger->prev;
		struct alarmclock_trigger *after  = trigger->next;

		/* Splice the neighbours together; a NULL `before` means the
		 * trigger is the queue head. */
		if (after)
			after->prev = before;
		if (before)
			before->next = after;
		else
			alarmclock->queue = after;

		trigger->alarmclock = NULL;
		trigger->next = NULL;
		trigger->prev = NULL;
	}
	cr_restore_interrupts(saved);
}