/* lib9p/tables.c - Access tables of version and message information
 *
 * Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#include <string.h>

#include <libmisc/endian.h>
#include <libmisc/log.h> /* for const_byte_str() */

#include "tables.h"

/* bounds checks **************************************************************/

/* Abort unless `ver` is a valid row index into the per-version tables,
 * i.e. falls in [0, LIB9P_VER_NUM). */
static inline void assert_ver(enum lib9p_version ver) {
#pragma GCC diagnostic push
/* `ver >= 0` can be always-true when the enum is unsigned. */
#pragma GCC diagnostic ignored "-Wtype-limits"
	assert(ver >= 0 && ver < LIB9P_VER_NUM);
#pragma GCC diagnostic pop
}

/*
 * Abort unless `typ` fits in a single on-the-wire message-type byte,
 * i.e. falls in [0, 0xFF].
 *
 * The upper bound must be inclusive.  Message-type bytes are read
 * straight off the network (_lib9p_validate() takes net_bytes[4] and
 * hands it to lib9p_msgtype_str(), which calls this), and 0xFF is a
 * perfectly good uint8_t.  An exclusive bound would let a peer trip this
 * assertion just by sending a message whose type byte is 0xFF.  The
 * recv/send tables are [0x80] entries indexed by typ/2, so they already
 * cover 0x00..0xFF.
 * NOTE(review): this assumes _lib9p_table_msg has 0x100 entries per
 * version, matching that coverage; confirm in tables.h.
 */
static inline void assert_typ(enum lib9p_msg_type typ) {
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtype-limits"
	assert(0 <= typ && typ <= 0xFF);
#pragma GCC diagnostic pop
}

/* simple lookups *************************************************************/

/* Return the name recorded in the version table for `ver`. */
const char *lib9p_version_str(enum lib9p_version ver) {
	assert_ver(ver);
	const char *name = _lib9p_table_ver[ver].name;
	return name;
}

/* Return the min_msg_size recorded in the version table for `ver`. */
uint32_t lib9p_version_min_msg_size(enum lib9p_version ver) {
	assert_ver(ver);
	uint32_t min_size = _lib9p_table_ver[ver].min_msg_size;
	return min_size;
}

/*
 * Return the name of message type `typ` under protocol version `ver`.
 * If the table has no name for that type, fall back to the
 * string form of the raw byte from const_byte_str().
 */
const char *lib9p_msgtype_str(enum lib9p_version ver, enum lib9p_msg_type typ) {
	assert_ver(ver);
	assert_typ(typ);
	const char *name = _lib9p_table_msg[ver][typ].name;
	if (name)
		return name;
	return const_byte_str(typ);
}

/* main message functions *****************************************************/

static
ssize_t _lib9p_validate(uint8_t                          xxx_low_typ_bit,
                        const char                      *xxx_errmsg,
                        const struct _lib9p_recv_tentry  xxx_table[LIB9P_VER_NUM][0x80],
                        struct lib9p_ctx *ctx, uint8_t *net_bytes) {
	assert_ver(ctx->version);
	/* Inspect the first 5 bytes ourselves.  */
	uint32_t net_size = uint32le_decode(net_bytes);
	if (net_size < 5)
		return lib9p_error(ctx, LINUX_EBADMSG, "message is impossibly short");
	uint8_t typ = net_bytes[4];
	if (typ % 2 != xxx_low_typ_bit)
		return lib9p_errorf(ctx, LINUX_EOPNOTSUPP, "%s: message_type=%s", xxx_errmsg,
		                    lib9p_msgtype_str(ctx->version, typ));
	struct _lib9p_recv_tentry tentry = xxx_table[ctx->version][typ/2];
	if (!tentry.validate)
		return lib9p_errorf(ctx, LINUX_EOPNOTSUPP, "unknown message type: %s (protocol_version=%s)",
		                    lib9p_msgtype_str(ctx->version, typ), lib9p_version_str(ctx->version));

	/* Now use the message-type-specific tentry to process the whole thing.  */
	return tentry.validate(ctx, net_size, net_bytes);
}

/* Validate `net_bytes` as a T-message (even type number) for ctx->version. */
ssize_t lib9p_Tmsg_validate(struct lib9p_ctx *ctx, uint8_t *net_bytes) {
	const uint8_t want_low_bit = 0;
	return _lib9p_validate(want_low_bit,
	                       "expected a T-message but got an R-message",
	                       _lib9p_table_Tmsg_recv, ctx, net_bytes);
}

/* Validate `net_bytes` as an R-message (odd type number) for ctx->version. */
ssize_t lib9p_Rmsg_validate(struct lib9p_ctx *ctx, uint8_t *net_bytes) {
	const uint8_t want_low_bit = 1;
	return _lib9p_validate(want_low_bit,
	                       "expected an R-message but got a T-message",
	                       _lib9p_table_Rmsg_recv, ctx, net_bytes);
}

/*
 * Shared implementation of lib9p_{T,R}msg_unmarshal().
 *
 * Reads the message type from byte 4 of `net_bytes`, reports it through
 * `*ret_typ`, and dispatches to the unmarshaler in
 * xxx_table[ctx->version][typ/2], which writes into `ret_body`.
 * This function does no validation of its own.
 * NOTE(review): callers presumably must have had the matching
 * lib9p_{T,R}msg_validate() succeed first; confirm against the header docs.
 */
static
void _lib9p_unmarshal(const struct _lib9p_recv_tentry xxx_table[LIB9P_VER_NUM][0x80],
                      struct lib9p_ctx *ctx, uint8_t *net_bytes,
                      enum lib9p_msg_type *ret_typ, void *ret_body) {
	assert_ver(ctx->version);
	enum lib9p_msg_type typ = net_bytes[4];
	*ret_typ = typ;
	struct _lib9p_recv_tentry tentry = xxx_table[ctx->version][typ/2];
	/* If this table entry has no unmarshaler, fail loudly here instead
	 * of calling through a NULL function pointer (undefined behavior). */
	assert(tentry.unmarshal);

	tentry.unmarshal(ctx, net_bytes, ret_body);
}

/* Unmarshal a T-message from `net_bytes` into `ret_body`, reporting its
 * type via `*ret_typ`. */
void lib9p_Tmsg_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes,
                          enum lib9p_msg_type *ret_typ, void *ret_body) {
	_lib9p_unmarshal(_lib9p_table_Tmsg_recv, ctx, net_bytes, ret_typ, ret_body);
}

/* Unmarshal an R-message from `net_bytes` into `ret_body`, reporting its
 * type via `*ret_typ`. */
void lib9p_Rmsg_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes,
                          enum lib9p_msg_type *ret_typ, void *ret_body) {
	_lib9p_unmarshal(_lib9p_table_Rmsg_recv, ctx, net_bytes, ret_typ, ret_body);
}

/*
 * Shared implementation of lib9p_{T,R}msg_marshal().
 *
 * Dispatches `body` to the marshaler in xxx_table[ctx->version][typ/2].
 * That marshaler fills `ret_iov` (starting from a single iovec) and copies
 * bytes into `ret_copied`.  A trailing iovec left with zero length is then
 * dropped, and the number of iovecs in use is stored in `*ret_iov_cnt`.
 *
 * Returns the marshaler's result: true means it reported an error.
 */
static
bool _lib9p_marshal(const struct _lib9p_send_tentry xxx_table[LIB9P_VER_NUM][0x80],
                    struct lib9p_ctx *ctx, enum lib9p_msg_type typ, void *body,
                    size_t *ret_iov_cnt, struct iovec *ret_iov, uint8_t *ret_copied) {
	assert_ver(ctx->version);
	assert_typ(typ);

	struct _lib9p_send_tentry tentry = xxx_table[ctx->version][typ/2];
	/* `typ` is caller-supplied and may have no encoding in this protocol
	 * version.  Fail loudly rather than call through a NULL pointer. */
	assert(tentry.marshal);

	struct _marshal_ret ret = {
		.net_iov_cnt     = 1,
		.net_iov         = ret_iov,
		.net_copied_size = 0,
		.net_copied      = ret_copied,
	};
	bool ret_erred = tentry.marshal(ctx, body, &ret);
	/* Don't report an empty trailing iovec to the caller. */
	if (ret_iov[ret.net_iov_cnt-1].iov_len == 0)
		ret.net_iov_cnt--;
	*ret_iov_cnt = ret.net_iov_cnt;
	return ret_erred;
}

/*
 * Marshal T-message `body` of type `typ` into `*ret`, which is
 * zero-filled first.  `typ` must be even (a T-message type).
 * Returns true if marshaling reported an error.
 */
bool lib9p_Tmsg_marshal(struct lib9p_ctx *ctx, enum lib9p_msg_type typ, void *body,
                        struct lib9p_Tmsg_send_buf *ret) {
	assert(typ % 2 == 0);
	memset(ret, 0, sizeof *ret);
	return _lib9p_marshal(_lib9p_table_Tmsg_send, ctx, typ, body,
	                      &ret->iov_cnt,
	                      ret->iov,
	                      ret->copied);
}

/*
 * Marshal R-message `body` of type `typ` into `*ret`, which is
 * zero-filled first.  `typ` must be odd (an R-message type).
 * Returns true if marshaling reported an error.
 */
bool lib9p_Rmsg_marshal(struct lib9p_ctx *ctx, enum lib9p_msg_type typ, void *body,
                        struct lib9p_Rmsg_send_buf *ret) {
	assert(typ % 2 == 1);
	memset(ret, 0, sizeof *ret);
	return _lib9p_marshal(_lib9p_table_Rmsg_send, ctx, typ, body,
	                      &ret->iov_cnt,
	                      ret->iov,
	                      ret->copied);
}

/* `struct lib9p_stat` helpers ************************************************/

bool lib9p_stat_validate(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes,
                         uint32_t *ret_net_size, ssize_t *ret_host_size) {
	ssize_t host_size = _lib9p_stat_validate(ctx, net_size, net_bytes, ret_net_size);
	if (host_size < 0)
		return true;
	if (ret_host_size)
		*ret_host_size = host_size;
	return false;
}

/* Public wrapper: unmarshal the stat at `net_bytes` into `*ret` by calling
 * the internal _lib9p_stat_unmarshal(). */
void lib9p_stat_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes,
                          struct lib9p_stat *ret) {
	_lib9p_stat_unmarshal(ctx,
	                      net_bytes,
	                      ret);
}

uint32_t lib9p_stat_marshal(struct lib9p_ctx *ctx, uint32_t max_net_size, struct lib9p_stat *obj,
                            uint8_t *ret_bytes) {
	struct lib9p_ctx _ctx = *ctx;
	_ctx.max_msg_size = max_net_size;

	struct iovec iov = {0};
	struct _marshal_ret ret = {
		.net_iov_cnt     = 1,
		.net_iov         = &iov,
		.net_copied_size = 0,
		.net_copied      = ret_bytes,
	};
	if (_lib9p_stat_marshal(&_ctx, obj, &ret))
		return 0;
	return ret.net_iov[0].iov_len;
}