/* libhw/rp2040_hwtimer.c - <libhw/generic/alarmclock.h> implementation for the RP2040's hardware timer
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <hardware/irq.h>   /* pico-sdk:hardware_irq */
#include <hardware/timer.h> /* pico-sdk:hardware_timer */

#include <libcr/coroutine.h>
#include <libmisc/assert.h>
#include <libmisc/vcall.h>

#define IMPLEMENTATION_FOR_LIBHW_GENERIC_ALARMCLOCK_H YES
#include <libhw/generic/alarmclock.h>
#include <libhw/rp2040_hwtimer.h>

/******************************************************************************/

/**
 * Conflict with pico-sdk:pico_time:!PICO_TIME_DEFAULT_ALARM_POOL_DISABLED.
 *
 * This deliberately defines a symbol named add_alarm_at() with a
 * dummy signature and an empty body.
 *
 * NOTE(review): presumably the intent is that the build fails with a
 * duplicate-symbol error if pico_time's default alarm pool (which
 * would compete for the hardware alarms used below) is compiled in --
 * confirm against the build configuration.
 */
void add_alarm_at(void) {}

/* Types **********************************************************************/

/**
 * State for one alarmclock instance backed by one of the RP2040 timer
 * peripheral's hardware alarms.  The instances live in the hwtimers[]
 * array below, one per alarm.
 */
struct rp2040_hwtimer {
	/* Embedded alarmclock interface; this is what provides the
	 * `.vtable` member that the hwtimers[] initializers set.  */
	implements_alarmclock;
	/* Which hardware alarm this instance drives; used as the
	 * index into timer_hw->alarm[] and as the bit position in the
	 * timer's inte/intf/intr registers.  */
	enum rp2040_hwalarm_instance     alarm_num;
	/* Set by add_trigger once the alarm's interrupt has been
	 * enabled and the IRQ handler installed.  */
	bool                             initialized;
	/* Head of the doubly-linked list of pending triggers;
	 * add_trigger keeps it sorted by ascending fire_at_ns.  */
	struct alarmclock_trigger       *queue;
};

/* Globals ********************************************************************/

/* Forward declarations of the method implementations (defined further
 * down) so that the vtable can reference them.  */
static uint64_t rp2040_hwtimer_get_time_ns(implements_alarmclock *self);
static bool     rp2040_hwtimer_add_trigger(implements_alarmclock *self,
                                           struct alarmclock_trigger *trigger,
                                           uint64_t   fire_at_ns,
                                           void       (*cb)(void *),
                                           void      *cb_arg);
static void     rp2040_hwtimer_del_trigger(implements_alarmclock *self,
                                           struct alarmclock_trigger *trigger);

/* The alarmclock method table shared by every entry of hwtimers[].  */
static struct alarmclock_vtable rp2040_hwtimer_vtable = {
	.get_time_ns = rp2040_hwtimer_get_time_ns,
	.add_trigger = rp2040_hwtimer_add_trigger,
	.del_trigger = rp2040_hwtimer_del_trigger,
};

/* One instance per hardware alarm, indexed by alarm number.  The
 * remaining members (`initialized`, `queue`) start out zeroed; they
 * are filled in lazily by add_trigger.  */
static struct rp2040_hwtimer hwtimers[] = {
	{ .vtable = &rp2040_hwtimer_vtable, .alarm_num = 0 },
	{ .vtable = &rp2040_hwtimer_vtable, .alarm_num = 1 },
	{ .vtable = &rp2040_hwtimer_vtable, .alarm_num = 2 },
	{ .vtable = &rp2040_hwtimer_vtable, .alarm_num = 3 },
};
/* Keep the table in sync with the enum that is used to index it.  */
static_assert(sizeof(hwtimers)/sizeof(hwtimers[0]) == _RP2040_HWALARM_NUM);

/* The global `bootclock` alarmclock is the instance backed by
 * hardware alarm 0.  */
implements_alarmclock *bootclock = &hwtimers[0];

/* Main implementation ********************************************************/

/**
 * Look up the alarmclock instance that is backed by the given
 * hardware alarm.
 *
 * @param alarm_num  which hardware alarm; must be less than
 *                   _RP2040_HWALARM_NUM (checked with assert())
 * @return a pointer to the corresponding statically-allocated entry
 *         of hwtimers[]
 */
implements_alarmclock *rp2040_hwtimer(enum rp2040_hwalarm_instance alarm_num) {
	assert(alarm_num < _RP2040_HWALARM_NUM);

	struct rp2040_hwtimer *instance = hwtimers + alarm_num;
	return instance;
}


/**
 * alarmclock `get_time_ns` method: read the hardware timer's 64-bit
 * microsecond counter and scale it to nanoseconds.
 *
 * The instance argument is unused; every instance reads the same
 * `timer_hw` counter.
 */
static uint64_t rp2040_hwtimer_get_time_ns(implements_alarmclock *) {
	uint64_t now_us = timer_time_us_64(timer_hw);
	return now_us * (NS_PER_S/US_PER_S);
}

/* Convert a nanosecond count to microseconds (the unit of the
 * hardware counter), dividing by NS_PER_S/US_PER_S via LM_CEILDIV.
 * NOTE(review): LM_CEILDIV is defined elsewhere; by its name it
 * rounds the quotient up, so a trigger is never considered due before
 * its fire_at_ns -- confirm.  */
#define NS_TO_US_ROUNDUP(x) LM_CEILDIV(x, NS_PER_S/US_PER_S)

/**
 * IRQ handler installed (by add_trigger) for each hardware alarm.
 *
 * Works out which alarm fired from the current exception number, runs
 * the callback of every queued trigger that is due, then acknowledges
 * the interrupt and re-arms the alarm for the new head of the queue
 * (if any).
 */
static void rp2040_hwtimer_intrhandler(void) {
	uint irq_num = __get_current_exception() - VTABLE_FIRST_IRQ;
	enum rp2040_hwalarm_instance alarm_num = TIMER_ALARM_NUM_FROM_IRQ(irq_num);
	assert(alarm_num < _RP2040_HWALARM_NUM);

	struct rp2040_hwtimer *alarmclock = &hwtimers[alarm_num];

	while (alarmclock->queue &&
	       NS_TO_US_ROUNDUP(alarmclock->queue->fire_at_ns) <= timer_time_us_64(timer_hw)) {
		struct alarmclock_trigger *trigger = alarmclock->queue;
		void (*cb)(void *) = trigger->cb;
		void  *cb_arg      = trigger->cb_arg;

		/* Fully unlink the trigger *before* running its
		 * callback, so that the queue is in a consistent
		 * state if the callback adds or deletes triggers
		 * (including re-adding this one).  */
		alarmclock->queue = trigger->next;
		if (alarmclock->queue)
			/* The new head must not keep pointing back at
			 * the trigger that was just removed, or a
			 * later del_trigger on it would not update
			 * alarmclock->queue.  */
			alarmclock->queue->prev = NULL;
		trigger->alarmclock = NULL;
		trigger->next = NULL;
		trigger->prev = NULL;

		cb(cb_arg);
	}

	hw_clear_bits(&timer_hw->intf, 1 << alarm_num); /* Clear "force"ing the interrupt.  */
	hw_clear_bits(&timer_hw->intr, 1 << alarm_num); /* Clear natural firing of the alarm.  */
	if (alarmclock->queue) {
		uint64_t fire_at_us = NS_TO_US_ROUNDUP(alarmclock->queue->fire_at_ns);
		/* The alarm register only holds the low 32 bits of
		 * the target time.  */
		timer_hw->alarm[alarm_num] = (uint32_t)fire_at_us;
		/* If the target time arrived between the loop above
		 * and arming the alarm, the alarm would not match
		 * until the 32-bit counter wraps around; "force" the
		 * interrupt (same mechanism as add_trigger) so that
		 * the handler runs again right away.  */
		if (fire_at_us <= timer_time_us_64(timer_hw))
			hw_set_bits(&timer_hw->intf, 1 << alarm_num);
	}
}

/**
 * alarmclock `add_trigger` method: schedule `cb(cb_arg)` to be called
 * from the alarm's interrupt handler once the hardware timer reaches
 * `fire_at_ns`.
 *
 * @param _alarmclock  the alarmclock instance (an entry of hwtimers[])
 * @param trigger      caller-owned storage for the queue entry; it is
 *                     linked into the instance's queue until it fires
 *                     or is removed with del_trigger
 * @param fire_at_ns   absolute time to fire at; must be non-zero
 * @param cb           callback to run; must be non-NULL
 * @param cb_arg       opaque argument passed to `cb`
 * @return true if the trigger was rejected because it is more than
 *         UINT32_MAX microseconds in the future (in which case
 *         `trigger` is left untouched); false if it was queued
 */
static bool rp2040_hwtimer_add_trigger(implements_alarmclock     *_alarmclock,
                                       struct alarmclock_trigger *trigger,
                                       uint64_t                   fire_at_ns,
                                       void                       (*cb)(void *),
                                       void                      *cb_arg) {
	struct rp2040_hwtimer *alarmclock =
		VCALL_SELF(struct rp2040_hwtimer, implements_alarmclock, _alarmclock);
	assert(alarmclock);
	assert(trigger);
	assert(fire_at_ns);
	assert(cb);

	/* The hardware alarm compares only 32 bits of the microsecond
	 * counter, so anything further out than that cannot be
	 * represented.  */
	uint64_t now_us     = timer_time_us_64(timer_hw);
	uint64_t fire_at_us = NS_TO_US_ROUNDUP(fire_at_ns);
	if (fire_at_us > now_us &&
	    (fire_at_us - now_us) > UINT32_MAX)
		/* Too far in the future.  */
		return true;

	trigger->alarmclock = alarmclock;
	trigger->fire_at_ns = fire_at_ns;
	trigger->cb         = cb;
	trigger->cb_arg     = cb_arg;

	bool saved = cr_save_and_disable_interrupts();

	/* Find the insertion point that keeps the queue sorted by
	 * fire_at_ns (after any existing entries with an equal time),
	 * remembering the entry that will precede the new one.  The
	 * predecessor has to be tracked explicitly: when inserting at
	 * the tail there is no successor to copy `prev` from.  */
	struct alarmclock_trigger  *prev = NULL;
	struct alarmclock_trigger **dst  = &alarmclock->queue;
	while (*dst && fire_at_ns >= (*dst)->fire_at_ns) {
		prev = *dst;
		dst  = &(*dst)->next;
	}
	trigger->next = *dst;
	trigger->prev = prev;
	if (*dst)
		(*dst)->prev = trigger;
	*dst = trigger;

	/* Lazily hook up the interrupt the first time this alarm is
	 * used.  */
	if (!alarmclock->initialized) {
		hw_set_bits(&timer_hw->inte, 1 << alarmclock->alarm_num);
		irq_set_exclusive_handler(TIMER_ALARM_IRQ_NUM(timer_hw, alarmclock->alarm_num),
		                          rp2040_hwtimer_intrhandler);
		irq_set_enabled(TIMER_ALARM_IRQ_NUM(timer_hw, alarmclock->alarm_num), true);
		alarmclock->initialized = true;
	}
	if (alarmclock->queue == trigger) {
		/* "Force" the interrupt handler to trigger as soon as
		 * we enable interrupts.  This handles the case of
		 * when fire_at_ns is before when we called
		 * cr_save_and_disable_interrupts().  We could check
		 * timer_time_us_64() again after calling
		 * cr_save_and_disable_interrupts() and do this
		 * conditionally, but I don't think that would be any
		 * more efficient than just letting the interrupt
		 * fire.  */
		hw_set_bits(&timer_hw->intf, 1 << alarmclock->alarm_num);
	}
	cr_restore_interrupts(saved);

	return false;
}

/**
 * alarmclock `del_trigger` method: cancel a pending trigger.
 *
 * If `trigger` is currently queued on this alarmclock it is unlinked
 * and its link fields are reset.  Otherwise it must not be queued on
 * any alarmclock (checked with assert()), and nothing is done.
 *
 * @param _alarmclock  the alarmclock instance (an entry of hwtimers[])
 * @param trigger      the trigger to cancel; must be non-NULL
 */
static void rp2040_hwtimer_del_trigger(implements_alarmclock     *_alarmclock,
                                       struct alarmclock_trigger *trigger) {
	struct rp2040_hwtimer *alarmclock =
		VCALL_SELF(struct rp2040_hwtimer, implements_alarmclock, _alarmclock);
	assert(alarmclock);
	assert(trigger);

	/* The queue is shared with the interrupt handler, so keep
	 * interrupts off while touching it.  */
	bool saved = cr_save_and_disable_interrupts();
	if (trigger->alarmclock != alarmclock) {
		/* Not on our queue; it had better not be on anyone
		 * else's either.  */
		assert(!trigger->alarmclock);
	} else {
		struct alarmclock_trigger *before = trigger->prev;
		struct alarmclock_trigger *after  = trigger->next;

		/* Splice the neighbours together; with no predecessor
		 * the trigger was the head of the queue.  */
		if (before)
			before->next = after;
		else
			alarmclock->queue = after;
		if (after)
			after->prev = before;

		trigger->alarmclock = NULL;
		trigger->prev = NULL;
		trigger->next = NULL;
	}
	cr_restore_interrupts(saved);
}