/* sbc_harness/fs_harness_flash_bin.c - 9P access to flash storage
 *
 * Copyright (C) 2025  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <hardware/flash.h>
#include <hardware/watchdog.h>

#define LOG_NAME FLASH
#include <libmisc/log.h>

#include <util9p/static.h>

#define IMPLEMENTATION_FOR_FS_HARNESS_FLASH_BIN YES
#include "fs_harness_flash_bin.h"

/* Wire `struct flash_file` up to the lib9p_srv_file interface.
 * NOTE(review): presumably LO_IMPLEMENTATION_C builds the interface
 * from the flash_file_* functions defined below -- confirm against
 * libmisc.  */
LO_IMPLEMENTATION_C(lib9p_srv_file, struct flash_file, flash_file, static);

/* The same object also acts as its own open-file handle (fio):
 * flash_file_fopen() returns `self` boxed as a lib9p_srv_fio.  */
LO_IMPLEMENTATION_H(lib9p_srv_fio, struct flash_file, flash_file);
LO_IMPLEMENTATION_C(lib9p_srv_fio, struct flash_file, flash_file, static);

/* Flash contents are read through the XIP_NOALLOC_BASE window
 * (NOTE(review): presumably the non-cache-allocating XIP alias --
 * confirm in the RP2040 datasheet).  The code below treats
 * `DATA_START+off` and the flash_range_*() offset `off` as the same
 * byte.
 *
 * The chip is split in two halves: the lower half [0, DATA_HSIZE) is
 * what ab_flash_finalize() copies into, and the upper half
 * [DATA_HSIZE, DATA_SIZE) is the staging area that writes go to.  */
#define DATA_START ((const char *)(XIP_NOALLOC_BASE))
#define DATA_SIZE  PICO_FLASH_SIZE_BYTES
#define DATA_HSIZE (DATA_SIZE/2)
static_assert(DATA_SIZE % FLASH_SECTOR_SIZE == 0);
static_assert(DATA_HSIZE % FLASH_SECTOR_SIZE == 0);

/* Several memcpy()s (and possibly memcmp()s) in this file read flash
 * through XIP; they could (and arguably should) be replaced with SSI
 * DMA.  */

/* ab_flash_* (mid-level utilities for our A/B write scheme) ******************/

/**
 * Copy the upper half of flash to the lower half of flash, then reboot.
 *
 * This is the "commit" step of the A/B write scheme: the image staged
 * in the upper half replaces the lower half.  Lower sectors that
 * already match their upper counterpart are skipped, so they are not
 * needlessly erased/reprogrammed.
 *
 * The function is placed outside of flash
 * (`__no_inline_not_in_flash_func`), presumably because the lower half
 * it rewrites holds the running firmware.  Interrupts are disabled for
 * the rest of its lifetime and never restored: the saved state
 * returned by cr_save_and_disable_interrupts() is deliberately
 * discarded, since this function never returns.
 *
 * NOTE(review): memcpy(), memcmp(), infof() and watchdog_reboot() are
 * called while/after the lower half is being overwritten.  If any of
 * them (or anything they call) executes from flash, it may run code
 * from the newly-written image instead of the current one -- confirm
 * they are RAM- or ROM-resident, or move the "rebooting..." log before
 * the copy loop.
 *
 * @param buf : a scratch buffer that is at least FLASH_SECTOR_SIZE
 */
[[noreturn]] static void __no_inline_not_in_flash_func(ab_flash_finalize)(uint8_t *buf) {
	assert(buf);

	infof("copying upper flash to lower flash...");

	/* Return value intentionally ignored: interrupts stay off until
	 * the reboot below.  */
	cr_save_and_disable_interrupts();

	for (size_t off = 0; off < DATA_HSIZE; off += FLASH_SECTOR_SIZE) {
		/* Stage the upper sector in RAM first.  NOTE(review):
		 * presumably required because flash cannot be read via
		 * XIP while flash_range_program() is running.  */
		memcpy(buf, DATA_START+DATA_HSIZE+off, FLASH_SECTOR_SIZE);
		if (memcmp(DATA_START+off, buf, FLASH_SECTOR_SIZE) == 0)
			continue;
		flash_range_erase(off, FLASH_SECTOR_SIZE);
		flash_range_program(off, buf, FLASH_SECTOR_SIZE);
	}

	infof("rebooting...");

	/* NOTE(review): per pico-sdk docs, pc=0/sp=0 requests a normal
	 * boot and 300 is the delay in ms -- confirm.  */
	watchdog_reboot(0, 0, 300);

	/* Spin until the watchdog fires.  */
	for (;;)
		asm volatile ("nop");
}

/**
 * Fill the upper (staging) half of flash with zero bytes.
 *
 * Sectors that already read back as all-zero are left untouched.
 *
 * @param buf : scratch space of at least FLASH_SECTOR_SIZE bytes; its
 *              previous contents are overwritten
 */
static void ab_flash_initialize_zero(uint8_t *buf) {
	assert(buf);

	memset(buf, 0, FLASH_SECTOR_SIZE);

	infof("zeroing upper flash...");
	size_t off = DATA_HSIZE;
	while (off < DATA_SIZE) {
		const char *sector = DATA_START + off;
		if (memcmp(buf, sector, FLASH_SECTOR_SIZE) != 0) {
			/* An erase would set every bit to 1, and
			 * programming can only clear bits to 0.  The
			 * target here is all 0s, so programming alone is
			 * enough and the erase is skipped.  */
			bool irq_state = cr_save_and_disable_interrupts();
			flash_range_program(off, buf, FLASH_SECTOR_SIZE);
			cr_restore_interrupts(irq_state);
		}
		off += FLASH_SECTOR_SIZE;
	}
	debugf("... zeroed");
}

/**
 * Make the upper (staging) half of flash an exact copy of the lower half.
 *
 * Upper sectors that already match their lower counterpart are skipped;
 * every other one is erased and reprogrammed.
 *
 * @param buf : scratch space of at least FLASH_SECTOR_SIZE bytes; its
 *              previous contents are overwritten
 */
static void ab_flash_initialize(uint8_t *buf) {
	assert(buf);

	infof("initializing upper flash...");
	for (size_t lo = 0, hi = DATA_HSIZE;
	     lo < DATA_HSIZE;
	     lo += FLASH_SECTOR_SIZE, hi += FLASH_SECTOR_SIZE) {
		/* Stage the source sector in RAM.  NOTE(review):
		 * presumably needed because flash is not readable via
		 * XIP while flash_range_program() runs.  */
		memcpy(buf, DATA_START+lo, FLASH_SECTOR_SIZE);
		bool same = memcmp(buf, DATA_START+hi, FLASH_SECTOR_SIZE) == 0;
		if (same)
			continue;

		bool irq_state = cr_save_and_disable_interrupts();
		flash_range_erase(hi, FLASH_SECTOR_SIZE);
		flash_range_program(hi, buf, FLASH_SECTOR_SIZE);
		cr_restore_interrupts(irq_state);
	}
	debugf("... initialized");
}

/**
 * Write one sector of data into the upper (staging) half of flash.
 *
 * `pos` names a sector in the LOWER half; the data is written to its
 * mirror sector at `pos`+DATA_HSIZE in the upper half.  If that sector
 * already holds exactly `dat`, the erase/program cycle is skipped.
 *
 * @param pos : start-offset of a sector in the lower half of the flash
 *              (must be < DATA_HSIZE and FLASH_SECTOR_SIZE-aligned)
 * @param dat : the FLASH_SECTOR_SIZE bytes to write (only read, never
 *              modified).  NOTE(review): presumably must live in RAM,
 *              not flash, since XIP is unavailable while programming.
 */
static void ab_flash_write_sector(size_t pos, const uint8_t *dat) {
	assert(pos < DATA_HSIZE);
	assert(pos % FLASH_SECTOR_SIZE == 0);
	assert(dat);

	/* Translate the lower-half offset to its upper-half mirror.  */
	pos += DATA_HSIZE;

	infof("write flash sector @ %zu...", pos);
	if (memcmp(dat, DATA_START+pos, FLASH_SECTOR_SIZE) != 0) {
		bool saved = cr_save_and_disable_interrupts();
		flash_range_erase(pos, FLASH_SECTOR_SIZE);
		flash_range_program(pos, dat, FLASH_SECTOR_SIZE);
		cr_restore_interrupts(saved);
	}
	debugf("... written");
}

/* srv_file *******************************************************************/

/**
 * Release a flash_file.
 *
 * There is nothing to release here; the function only checks that it
 * was handed a non-NULL object.
 */
static void flash_file_free(struct flash_file *self)
{
	assert(self);
}
static struct lib9p_qid flash_file_qid(struct flash_file *self) {
	assert(self);

	return (struct lib9p_qid){
		.type = LIB9P_QT_FILE|LIB9P_QT_EXCL,
		.vers = 1,
		.path = self->pathnum,
	};
}

static struct lib9p_stat flash_file_stat(struct flash_file *self, struct lib9p_srv_ctx *ctx) {
	assert(self);
	assert(ctx);

	return (struct lib9p_stat){
		.kern_type                = 0,
		.kern_dev                 = 0,
		.file_qid                 = flash_file_qid(self),
		.file_mode                = LIB9P_DM_EXCL|0666,
		.file_atime               = UTIL9P_ATIME,
		.file_mtime               = UTIL9P_MTIME,
		.file_size                = DATA_SIZE,
		.file_name                = lib9p_str(self->name),
		.file_owner_uid           = lib9p_str("root"),
		.file_owner_gid           = lib9p_str("root"),
		.file_last_modified_uid   = lib9p_str("root"),
		.file_extension           = lib9p_str(NULL),
		.file_owner_n_uid         = 0,
		.file_owner_n_gid         = 0,
		.file_last_modified_n_uid = 0,
	};
}
/**
 * Refuse any attempt to change the flash file's metadata.
 *
 * The requested stat is ignored, and the request always fails with
 * EROFS.
 */
static void flash_file_wstat(struct flash_file *self,
                             struct lib9p_srv_ctx *ctx,
                             struct lib9p_stat) {
	assert(self);
	assert(ctx);

	lib9p_error(&ctx->basectx,
	            LINUX_EROFS, "read-only part of filesystem");
}
/**
 * Refuse any attempt to remove the flash file; always fails with EROFS.
 */
static void flash_file_remove(struct flash_file *self,
                              struct lib9p_srv_ctx *ctx) {
	assert(self);
	assert(ctx);

	lib9p_error(&ctx->basectx,
	            LINUX_EROFS, "read-only part of filesystem");
}

/* Supply the directory-only lib9p_srv_file methods for this
 * non-directory file.  NOTE(review): presumably they fail with
 * ENOTDIR -- confirm against the LIB9P_SRV_NOTDIR definition.  */
LIB9P_SRV_NOTDIR(struct flash_file, flash_file);

/**
 * Open the flash file and return a fio for it.
 *
 * Opening for reading discards the cached read sector (`rbuf`).
 *
 * Opening for writing prepares the upper (staging) half of flash:
 *  - with `trunc`, the staging half is zeroed and the file is marked
 *    as written, so closing the fio will commit even if nothing more
 *    is written;
 *  - without `trunc`, the staging half is seeded with a copy of the
 *    lower half, and nothing is committed unless a pwrite happens.
 * In both cases any pending write-sector buffer (`wbuf`) is discarded.
 *
 * The returned fio is this same object.
 */
static lo_interface lib9p_srv_fio flash_file_fopen(struct flash_file *self, struct lib9p_srv_ctx *ctx,
                                                   bool rd, bool wr, bool trunc) {
	assert(self);
	assert(ctx);

	if (rd)
		self->rbuf.ok = false;

	if (wr) {
		if (trunc)
			ab_flash_initialize_zero(self->wbuf.dat);
		else
			ab_flash_initialize(self->wbuf.dat);
		/* A truncating open is itself a modification.  */
		self->written = trunc;
		self->wbuf.ok = false;
	}

	return lo_box_flash_file_as_lib9p_srv_fio(self);
}

/* srv_fio ********************************************************************/

/**
 * Report FLASH_SECTOR_SIZE as the I/O unit.
 *
 * This matches pread/pwrite, which never transfer more than one sector
 * per call.
 */
static uint32_t flash_file_iounit(struct flash_file *self) {
	assert(self);

	const uint32_t unit = FLASH_SECTOR_SIZE;
	return unit;
}

/**
 * Close a fio.
 *
 * Any pending write sector is first flushed to the staging half.  Then,
 * if the file was written during this session, the staging half is
 * committed over the lower half and the chip reboots; in that case this
 * does not return.
 *
 * NOTE(review): rbuf/wbuf/written live on the shared flash_file, not on
 * a per-open handle (fopen returns `self`).  Closing *any* fio of this
 * file therefore commits and reboots once `written` is set.  The
 * exclusive-use QID/mode bits may already prevent concurrent opens --
 * confirm.
 */
static void flash_file_iofree(struct flash_file *self) {
	assert(self);

	if (self->wbuf.ok) {
		ab_flash_write_sector(self->wbuf.pos, self->wbuf.dat);
	}

	if (!self->written)
		return;

	ab_flash_finalize(self->wbuf.dat);
}

/**
 * Read from the flash file.
 *
 * At most one sector is returned per call.  The request is clipped so
 * it does not extend past the end of the sector containing
 * `byte_offset`.  The returned iovec points into `self->rbuf`, a RAM
 * copy of that sector, which is re-fetched only when a different
 * sector is requested (or after fopen invalidates it).
 *
 * Reading at exactly DATA_SIZE yields an empty iovec (EOF); reading
 * beyond it fails with EINVAL.
 */
static void flash_file_pread(struct flash_file *self, struct lib9p_srv_ctx *ctx,
                             uint32_t byte_count, uint64_t byte_offset,
                             struct iovec *ret) {
	assert(self);
	assert(ctx);
	assert(ret);

	if (byte_offset >= DATA_SIZE) {
		if (byte_offset == DATA_SIZE)
			*ret = (struct iovec){ .iov_len = 0 };
		else
			lib9p_error(&ctx->basectx,
			            LINUX_EINVAL, "offset is past the chip size");
		return;
	}

	size_t sector_base = LM_ROUND_DOWN(byte_offset, FLASH_SECTOR_SIZE);
	size_t in_sector   = byte_offset - sector_base;
	size_t room        = FLASH_SECTOR_SIZE - in_sector;
	if (byte_count > room)
		byte_count = (uint32_t)room;
	assert(byte_offset + byte_count <= DATA_SIZE);

	/* Assume that somewhere down the line the returned iovec is
	 * handed to DMA.  We don't want the DMA engine to hit (slow)
	 * XIP; for instance, this can cause reads/writes to the SSP to
	 * get out of sync with each other.  So the sector is copied to a
	 * (fast) RAM buffer first.  It's lame that the DMA engine can
	 * only have a DREQ on one side of the channel.  */
	bool cached = self->rbuf.ok && self->rbuf.pos == sector_base;
	if (!cached) {
		self->rbuf.ok = true;
		self->rbuf.pos = sector_base;
		memcpy(self->rbuf.dat, DATA_START+sector_base, FLASH_SECTOR_SIZE);
	}

	*ret = (struct iovec){
		.iov_base = self->rbuf.dat + in_sector,
		.iov_len  = byte_count,
	};
}

/**
 * Write to the flash file.
 *
 * Writes never touch the lower half of flash directly.  They are staged
 * into the upper half one sector at a time through `self->wbuf`, and
 * ab_flash_finalize() (called when the fio is closed) copies them down.
 *
 * At most one sector is accepted per call: the request is clipped so it
 * does not extend past the end of the sector containing `byte_offset`,
 * and the (possibly short) count actually accepted is returned.  When a
 * write moves to a different sector, the previously buffered sector is
 * flushed to the staging half first.
 *
 * Only offsets in [0, DATA_HSIZE) are writable.  A write at or beyond
 * DATA_HSIZE fails with EINVAL, and a zero-length write returns 0.
 *
 * TODO: Short/corrupt writes are dangerous.  This should either (1)
 * check a checksum, (2) use uf2 instead of verbatim data, or (3) use
 * ihex instead of verbatim data.
 */
static uint32_t flash_file_pwrite(struct flash_file *self, struct lib9p_srv_ctx *ctx,
                                  void *buf,
                                  uint32_t byte_count,
                                  uint64_t byte_offset) {
	assert(self);
	assert(ctx);

	if (byte_offset > DATA_HSIZE) {
		lib9p_error(&ctx->basectx,
		            LINUX_EINVAL, "offset is past half the chip size");
		return 0;
	}
	if (byte_count == 0)
		return 0;
	assert(buf);
	if (byte_offset == DATA_HSIZE) {
		lib9p_error(&ctx->basectx,
		            LINUX_EINVAL, "offset is at half the chip size");
		return 0;
	}

	size_t sector_base = LM_ROUND_DOWN(byte_offset, FLASH_SECTOR_SIZE);
	if (byte_offset + byte_count > sector_base + FLASH_SECTOR_SIZE)
		byte_count = (sector_base + FLASH_SECTOR_SIZE) - byte_offset;
	/* `<=`, not `<`: a write that fills the final sector ends exactly
	 * at DATA_HSIZE, which is valid.  */
	assert(byte_offset + byte_count <= DATA_HSIZE);

	/* Moving to a different sector: flush the one we were buffering.  */
	if (self->wbuf.ok && self->wbuf.pos != sector_base)
		ab_flash_write_sector(self->wbuf.pos, self->wbuf.dat);
	if (!self->wbuf.ok || self->wbuf.pos != sector_base) {
		self->wbuf.ok = true;
		self->wbuf.pos = sector_base;
		/* A partial-sector write must preserve the rest of the
		 * sector, so seed the buffer from the staging half; a
		 * full-sector write overwrites it all anyway.  */
		if (byte_count != FLASH_SECTOR_SIZE)
			memcpy(self->wbuf.dat, DATA_START+DATA_HSIZE+sector_base, FLASH_SECTOR_SIZE);
	}
	memcpy(&self->wbuf.dat[byte_offset-sector_base], buf, byte_count);

	self->written = true;
	return byte_count;
}