/* lib9p/tests/test_server/main.c - Main entry point for test 9P server
 *
 * Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <error.h>

#include <lib9p/srv.h>
#include <libcr/coroutine.h>
#include <libhw/generic/net.h>
#include <libhw/generic/alarmclock.h>
#include <libhw/host_alarmclock.h>
#include <libhw/host_net.h>
#include <libmisc/macro.h>
#include <util9p/static.h>

#include "static.h"

/* configuration **************************************************************/

#include "config.h"

#ifndef CONFIG_SRV9P_NUM_CONNS
	#error config.h must define CONFIG_SRV9P_NUM_CONNS
#endif

/* globals ********************************************************************/

/* Forward declaration: the initializer of `globals` below stores this function
 * in .srv.rootdir, ahead of the function's definition further down the file. */
static lo_interface lib9p_srv_file get_root(struct lib9p_srv_ctx *ctx, struct lib9p_s treename);

/* The 16 lowercase hexadecimal digits, indexed by value (0..15).  Used by
 * init_cr() to build the one-character suffix of each coroutine name. */
const char *hexdig = "0123456789abcdef";

/* Process-wide server state.
 *
 *  - listeners: one TCP listener per read coroutine; read_cr() initializes
 *    slot [i] and api_pwrite() closes all of them.
 *  - srv: the 9P server instance handed to the read and write coroutines; its
 *    root directory is looked up through get_root().
 */
struct {
	struct hostnet_tcp_listener     listeners[CONFIG_SRV9P_NUM_CONNS];
	struct lib9p_srv                srv;
} globals = {
	.srv = {
		.rootdir = get_root,
	},
};

/* api ************************************************************************/

/* The "API" file: a special file in the served tree.  It stats under the name
 * "shutdown" (see api_stat()), and any non-empty write to it closes every
 * listener in `globals` (see api_pwrite()). */
struct api_file {
	uint64_t pathnum; /* reported as the qid path by api_qid() */
};
/* NOTE(review): presumably these two macros declare and define the glue that
 * lets a `struct api_file` be used as a `lib9p_srv_file` through the api_*()
 * functions below -- confirm against the LO_IMPLEMENTATION_* definitions. */
LO_IMPLEMENTATION_H(lib9p_srv_file, struct api_file, api)
LO_IMPLEMENTATION_C(lib9p_srv_file, struct api_file, api, static)

/* "free" method of the API file.  Nothing is released here; the function only
 * checks that `self` is non-NULL. */
static void api_free(struct api_file *self) {
	assert(self);
}
static struct lib9p_qid api_qid(struct api_file *self) {
	assert(self);
	return (struct lib9p_qid){
		.type = LIB9P_QT_FILE,
		.vers = 1,
		.path = self->pathnum,
	};
}
/* "chio" method of the API file: checks its pointer arguments, ignores the
 * three flags, and always returns 0.
 * NOTE(review): the meaning of the three unnamed bool parameters and of the
 * return value is defined by the lib9p_srv_file interface, which is not
 * visible in this file -- confirm there. */
static uint32_t api_chio(struct api_file *self, struct lib9p_srv_ctx *ctx, bool, bool, bool) {
	assert(self);
	assert(ctx);
	return 0;
}

static struct lib9p_stat api_stat(struct api_file *self, struct lib9p_srv_ctx *ctx) {
	assert(self);
	assert(ctx);
	return (struct lib9p_stat){
		.kern_type                = 0,
		.kern_dev                 = 0,
		.file_qid                 = api_qid(self),
		.file_mode                = 0222,
		.file_atime               = UTIL9P_ATIME,
		.file_mtime               = UTIL9P_MTIME,
		.file_size                = 0,
		.file_name                = lib9p_str("shutdown"),
		.file_owner_uid           = lib9p_str("root"),
		.file_owner_gid           = lib9p_str("root"),
		.file_last_modified_uid   = lib9p_str("root"),
		.file_extension           = lib9p_str(NULL),
		.file_owner_n_uid         = 0,
		.file_owner_n_gid         = 0,
		.file_last_modified_n_uid = 0,
	};
}
/* "wstat" method of the API file: metadata changes are always refused.  The
 * requested stat is ignored and LINUX_EROFS is reported on the request's base
 * context through lib9p_error(). */
static void api_wstat(struct api_file *self, struct lib9p_srv_ctx *ctx, struct lib9p_stat) {
	assert(self);
	assert(ctx);
	lib9p_error(&ctx->basectx, LINUX_EROFS, "cannot wstat API file");
}
/* "remove" method of the API file: removal is always refused.  LINUX_EROFS is
 * reported on the request's base context through lib9p_error(). */
static void api_remove(struct api_file *self, struct lib9p_srv_ctx *ctx) {
	assert(self);
	assert(ctx);
	lib9p_error(&ctx->basectx, LINUX_EROFS, "cannot remove API file");
}

/* NOTE(review): presumably supplies the directory-only methods of the
 * lib9p_srv_file interface for this non-directory type -- confirm against the
 * macro's definition in the lib9p headers. */
LIB9P_SRV_NOTDIR(struct api_file, api)

/* "pwrite" method of the API file.
 *
 * A zero-length write does nothing.  Any other write, whatever its contents
 * or offset, invokes `close` on every listener in globals.listeners.
 *
 * Returns byte_count: the whole write is reported as consumed.
 */
static uint32_t api_pwrite(struct api_file *self, struct lib9p_srv_ctx *ctx, void *buf, uint32_t byte_count, uint64_t LM_UNUSED(offset)) {
	assert(self);
	assert(ctx);
	assert(buf);

	if (byte_count != 0) {
		int idx = 0;
		while (idx < CONFIG_SRV9P_NUM_CONNS) {
			LO_CALL(lo_box_hostnet_tcplist_as_net_stream_listener(&globals.listeners[idx]), close);
			idx++;
		}
	}
	return byte_count;
}
/* "pread" method of the API file.  Reading is not supported: api_stat()
 * reports mode 0222, and reaching this function is treated as a programming
 * error via assert_notreached(). */
static uint32_t api_pread(struct api_file *, struct lib9p_srv_ctx *, void *, uint32_t, uint64_t) {
	assert_notreached("not readable");
}

/* Wrap a `struct api_file *` as a lib9p_srv_file value by way of util9p_box();
 * used when placing the API file into the `root` tree below. */
#define lo_box_api_as_lib9p_srv_file(obj) util9p_box(api, obj)

/* file tree ******************************************************************/

/* Path numbers for the file tree are handed out at preprocessing time.
 * PATH_BASE records the value of __COUNTER__ at this point; each later
 * expansion of PATH_COUNTER yields the next offset from it (1, 2, 3, ...),
 * provided nothing else expands __COUNTER__ in between. */
enum { PATH_BASE = __COUNTER__ };
/* The replacement list is parenthesized so that operators next to an
 * expansion site cannot bind to only part of the subtraction. */
#define PATH_COUNTER (__COUNTER__ - PATH_BASE)

/* STATIC_FILE(STRNAME, SYMNAME): a UTIL9P_STATIC_FILE named STRNAME that takes
 * the next PATH_COUNTER value as its path number and whose contents are
 * delimited by the symbols _binary_static_<SYMNAME>_start and
 * _binary_static_<SYMNAME>_end (presumably declared in "static.h" -- confirm).
 */
#define STATIC_FILE(STRNAME, SYMNAME)                                      \
	UTIL9P_STATIC_FILE(PATH_COUNTER, STRNAME,                          \
	                   .data_start = _binary_static_##SYMNAME##_start, \
	                   .data_end   = _binary_static_##SYMNAME##_end)
/* STATIC_DIR(STRNAME, ...): a UTIL9P_STATIC_DIR named STRNAME that takes the
 * next PATH_COUNTER value as its path number; the remaining arguments are
 * forwarded unchanged. */
#define STATIC_DIR(STRNAME, ...)  \
	UTIL9P_STATIC_DIR(PATH_COUNTER, STRNAME, __VA_ARGS__)

/* The tree served by this test server:
 *
 *   ""  (root directory)
 *   |-- Documentation/
 *   |   `-- x
 *   |-- README.md
 *   `-- the API file (stats as "shutdown", see api_stat())
 *
 * Every node takes its path number from a PATH_COUNTER expansion, so the
 * numbers depend on the order in which the preprocessor expands the entries
 * below; reordering entries renumbers them.
 */
struct lib9p_srv_file root =
		STATIC_DIR("",
		           STATIC_DIR("Documentation",
		                      STATIC_FILE("x", Documentation_x),
		                      ),
		           STATIC_FILE("README.md", README_md),
		           lo_box_api_as_lib9p_srv_file(&(struct api_file){.pathnum = PATH_COUNTER}),
		           );

/* Root-directory callback installed in globals.srv.rootdir.  Both arguments
 * are ignored: every request gets the same static `root` tree. */
static lo_interface lib9p_srv_file get_root(struct lib9p_srv_ctx *LM_UNUSED(ctx), struct lib9p_s LM_UNUSED(treename)) {
	return root;
}

/* main ***********************************************************************/

/* Body of one "read-N" coroutine.
 *
 * _i points to an int holding this coroutine's listener index.  It is
 * dereferenced only on the line before cr_begin(), where the value is copied
 * into a local.  The coroutine then initializes globals.listeners[i] with port
 * 9000 (every index uses that same port number) and passes the listener to
 * lib9p_srv_read_cr() together with globals.srv.
 */
static COROUTINE read_cr(void *_i) {
	int i = *((int *)_i);
	cr_begin();

	hostnet_tcp_listener_init(&globals.listeners[i], 9000);

	lib9p_srv_read_cr(&globals.srv, lo_box_hostnet_tcplist_as_net_stream_listener(&globals.listeners[i]));

	cr_end();
}

/* Bootstrap coroutine.
 *
 * After a 1ms sleep it adds CONFIG_SRV9P_NUM_CONNS "read-N" coroutines
 * (read_cr) and 2*CONFIG_SRV9P_NUM_CONNS "write-N" coroutines
 * (lib9p_srv_write_cr on globals.srv), then exits.  If any coroutine_add()
 * fails, the process is terminated through error(1, ...).
 *
 * The one-character name suffix wraps around every 16 coroutines.  hexdig
 * holds only 16 digits, so indexing it with the raw loop counter would read
 * past the table once a loop bound exceeds 16.
 */
static COROUTINE init_cr(void *) {
	cr_begin();

	/* NOTE(review): the reason for this delay is not evident from this file. */
	sleep_for_ms(1);

	for (int i = 0; i < CONFIG_SRV9P_NUM_CONNS; i++) {
		char name[] = {'r', 'e', 'a', 'd', '-', hexdig[i % 16], '\0'};
		/* &i is the address of the loop variable; read_cr() copies the
		 * value out before its cr_begin().  NOTE(review): this relies
		 * on coroutine_add() running read_cr() that far before it
		 * returns -- confirm against libcr. */
		if (!coroutine_add(name, read_cr, &i))
			error(1, 0, "coroutine_add(read_cr, &i)");
	}
	for (int i = 0; i < 2*CONFIG_SRV9P_NUM_CONNS; i++) {
		char name[] = {'w', 'r', 'i', 't', 'e', '-', hexdig[i % 16], '\0'};
		if (!coroutine_add(name, lib9p_srv_write_cr, &globals.srv))
			error(1, 0, "coroutine_add(lib9p_srv_write_cr, &globals.srv)");
	}

	cr_exit();
}

/* Program entry point.
 *
 * Installs a CLOCK_MONOTONIC-backed host clock as `bootclock`, registers the
 * "init" coroutine (init_cr), and then runs coroutine_main().  As in
 * init_cr(), a failed coroutine_add() terminates the process through
 * error(1, ...) rather than continuing with no coroutines registered.
 *
 * Returns 0 once coroutine_main() returns.
 */
int main(void) {
	struct hostclock clock_monotonic = {
		.clock_id = CLOCK_MONOTONIC,
	};
	bootclock = lo_box_hostclock_as_alarmclock(&clock_monotonic);
	if (!coroutine_add("init", init_cr, NULL))
		error(1, 0, "coroutine_add(init_cr, NULL)");
	coroutine_main();
	return 0;
}