/* lib9p/internal.h - Internal declarations and helpers shared within lib9p
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#ifndef _LIB9P_INTERNAL_H_
#define _LIB9P_INTERNAL_H_

#include <assert.h> /* for assert(), static_assert() */
#include <limits.h> /* for SSIZE_MAX */
#include <stdbool.h> /* for bool */
#include <stddef.h> /* for size_t */
#include <stdint.h> /* for uint{8,16,32,64}_t, UINT16_MAX */

#include <lib9p/9p.h>

/* configuration **************************************************************/

/* config.h must define every CONFIG_9P_* limit below; a missing one, or
 * one that violates the relationships asserted at the end, stops the
 * build instead of misbehaving at run time.  */
#include "config.h"

#ifndef CONFIG_9P_MAX_MSG_SIZE
	#error config.h must define CONFIG_9P_MAX_MSG_SIZE
#endif
#ifndef CONFIG_9P_MAX_HOSTMSG_SIZE
	#error config.h must define CONFIG_9P_MAX_HOSTMSG_SIZE
#endif
#ifndef CONFIG_9P_MAX_FIDS
	#error config.h must define CONFIG_9P_MAX_FIDS
#endif
#ifndef CONFIG_9P_MAX_REQS
	#error config.h must define CONFIG_9P_MAX_REQS
#endif
#ifndef CONFIG_9P_MAX_ERR_SIZE
	#error config.h must define CONFIG_9P_MAX_ERR_SIZE
#endif

/* The two-argument form of static_assert is used on purpose: the
 * one-argument form only exists as of C23, so with a message this
 * header also builds as C11, and a failure names the broken rule.  */
static_assert(CONFIG_9P_MAX_ERR_SIZE <= UINT16_MAX,
              "CONFIG_9P_MAX_ERR_SIZE must be <= UINT16_MAX");
static_assert(CONFIG_9P_MAX_MSG_SIZE <= CONFIG_9P_MAX_HOSTMSG_SIZE,
              "CONFIG_9P_MAX_MSG_SIZE must be <= CONFIG_9P_MAX_HOSTMSG_SIZE");
static_assert(CONFIG_9P_MAX_HOSTMSG_SIZE <= SSIZE_MAX,
              "CONFIG_9P_MAX_HOSTMSG_SIZE must be <= SSIZE_MAX");

/* C language *****************************************************************/

/* Wrap the name of an unused function parameter.  The replacement list
 * below is only a comment, so this macro expands to nothing: in a
 * parameter list `int UNUSED(x)` becomes an unnamed parameter (allowed
 * in function definitions as of C23).  The commented-out text records
 * the `__attribute__((unused))` alternative.  */
#define UNUSED(name)            /* name __attribute__((unused)) */
/* GCC `always_inline`: inline the function even when not optimizing.  */
#define ALWAYS_INLINE           inline __attribute__((always_inline))
/* GCC `flatten`: inline every call made inside the function's body,
 * where possible.  */
#define FLATTEN                 __attribute__((flatten))
/* Element count of an array.  Only valid for a true array: given a
 * pointer (including an array-typed function parameter) it silently
 * yields sizeof(pointer)/sizeof(element) instead.  */
#define ARRAY_LEN(arr)          (sizeof(arr)/sizeof((arr)[0]))
/* Token-paste 2 or 3 tokens.  Arguments are pasted as written; they
 * are NOT macro-expanded before pasting.  */
#define CAT2(a, b)              a##b
#define CAT3(a, b, c)           a##b##c

/* specialized contexts *******************************************************/

/* State passed to a `_validate_fn_t` callback.
 *
 * NOTE(review): judging by the field names, `net_bytes`/`net_size` are
 * the received wire-format message and its length, and `net_offset` is
 * the read cursor into it -- confirm against the validators.  */
struct _validate_ctx {
	struct lib9p_ctx        *ctx;
	uint32_t                 net_size;
	uint8_t                 *net_bytes;

	uint32_t                 net_offset;
	/* Increment `host_extra` to pre-allocate space that is
	 * "extra" beyond sizeof().  */
	size_t                   host_extra;
};
/* Validator callback.  NOTE(review): confirm whether a `true` result
 * means "valid" or "an error was reported".  */
typedef bool (*_validate_fn_t)(struct _validate_ctx *ctx);

/* State passed to an `_unmarshal_fn_t` callback.
 *
 * NOTE(review): there is no size field here to bounds-check against,
 * so presumably `net_bytes` has already passed validation and
 * `net_offset` is the read cursor -- confirm.  */
struct _unmarshal_ctx {
	struct lib9p_ctx        *ctx;
	uint8_t                 *net_bytes;

	uint32_t                 net_offset;
	/* `extra` points to the beginning of unallocated space.  */
	void                    *extra;
};
/* Unmarshal callback writing into `out`; its `void` return means it
 * has no way to report failure.  */
typedef void (*_unmarshal_fn_t)(struct _unmarshal_ctx *ctx, void *out);

/* State passed to a `_marshal_fn_t` callback.
 *
 * NOTE(review): presumably `net_bytes` is the output buffer and
 * `net_offset` the write cursor.  This struct carries no capacity
 * field, so any size limit must come from elsewhere (e.g. `ctx`) --
 * confirm.  */
struct _marshal_ctx {
	struct lib9p_ctx        *ctx;

	uint8_t                 *net_bytes;
	uint32_t                 net_offset;
};
/* Marshal callback reading from `host_val`.  NOTE(review): confirm the
 * meaning of the `bool` result.  */
typedef bool (*_marshal_fn_t)(struct _marshal_ctx *ctx, void *host_val);

/* tables *********************************************************************/

/* One entry of a version table: a name, a base size, and the
 * validate/unmarshal/marshal callbacks for one message type.  */
struct _table_msg {
	/* NOTE(review): if this always points at a string literal it could
	 * be `const char *`; check every user before changing it.  */
	char                   *name;
	/* NOTE(review): presumably sizeof() the host-side struct, to which
	 * `_validate_ctx.host_extra` is added -- confirm.  */
	size_t                  basesize;
	_validate_fn_t          validate;
	_unmarshal_fn_t         unmarshal;
	_marshal_fn_t           marshal;
};

/* `msgs` has 0x100 entries, one per possible byte value -- presumably
 * indexed by the one-byte 9P message type.  */
struct _table_version {
	struct _table_msg       msgs[0x100];
};

/* Sized by LIB9P_VER_NUM, presumably one table per protocol version.
 * This is only a declaration; the definition lives in another
 * translation unit.  */
extern struct _table_version _lib9p_versions[LIB9P_VER_NUM];

/* Helpers for `struct lib9p_stat`.  Their shapes mirror the callback
 * typedefs `_validate_fn_t`, `_unmarshal_fn_t` and `_marshal_fn_t`,
 * but with a typed `struct lib9p_stat *` in place of `void *`.
 * Defined in another translation unit.  */
bool _lib9p_validate_stat(struct _validate_ctx *ctx);
void _lib9p_unmarshal_stat(struct _unmarshal_ctx *ctx, struct lib9p_stat *out);
bool _lib9p_marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val);

/* unmarshal utilities ********************************************************/

/* Read a 1-byte unsigned integer from `in[0]`.  `in` is only read, so
 * it is taken as pointer-to-const; non-const callers are unaffected.  */
static ALWAYS_INLINE uint8_t decode_u8le(const uint8_t *in) {
	return in[0];
}
/* Read a 2-byte little-endian unsigned integer from `in[0..1]`.
 * The shifted operands are promoted to int, so the result is cast back
 * to uint16_t explicitly (the value always fits; this just keeps
 * -Wconversion quiet).  */
static ALWAYS_INLINE uint16_t decode_u16le(const uint8_t *in) {
	return (uint16_t)( (((uint16_t)(in[0])) <<  0)
	                 | (((uint16_t)(in[1])) <<  8)
	                 );
}
/* Read a 4-byte little-endian unsigned integer from `in[0..3]`.
 * `in` is only read, so it is taken as pointer-to-const.  */
static ALWAYS_INLINE uint32_t decode_u32le(const uint8_t *in) {
	return (((uint32_t)(in[0])) <<  0)
	     | (((uint32_t)(in[1])) <<  8)
	     | (((uint32_t)(in[2])) << 16)
	     | (((uint32_t)(in[3])) << 24)
	     ;
}
/* Read an 8-byte little-endian unsigned integer from `in[0..7]`.
 * `in` is only read, so it is taken as pointer-to-const.  */
static ALWAYS_INLINE uint64_t decode_u64le(const uint8_t *in) {
	return (((uint64_t)(in[0])) <<  0)
	     | (((uint64_t)(in[1])) <<  8)
	     | (((uint64_t)(in[2])) << 16)
	     | (((uint64_t)(in[3])) << 24)
	     | (((uint64_t)(in[4])) << 32)
	     | (((uint64_t)(in[5])) << 40)
	     | (((uint64_t)(in[6])) << 48)
	     | (((uint64_t)(in[7])) << 56)
	     ;
}

/* Check whether the `len` bytes at `str` are well-formed UTF-8 as
 * defined by RFC 3629.
 *
 * Rejected:
 *  - bytes that cannot start a sequence (0x80-0xBF, 0xF8-0xFF);
 *  - sequences cut off by the end of the buffer, or with a byte that
 *    is not a continuation byte (10xxxxxx);
 *  - overlong encodings (e.g. "C0 80" for U+0000, "C1 80" for '@',
 *    "E0 80 AF" for '/');
 *  - UTF-16 surrogates U+D800..U+DFFF;
 *  - code points above U+10FFFF;
 *  - U+0000 (a literal 0x00 byte) when `forbid_nul` is true.
 *
 * The overlong check happens after the whole sequence is decoded.
 * Checking only the lead byte's payload bits would wrongly reject
 * valid characters whose lead byte is 0xE0 or 0xF0 (e.g. U+0800, or
 * emoji such as U+1F600).
 *
 * @param str         bytes to check; must be non-NULL even if len == 0
 * @param len         number of bytes at `str`
 * @param forbid_nul  reject the NUL character if true
 * @return true iff all `len` bytes form valid UTF-8
 */
static inline bool _is_valid_utf8(const uint8_t *str, size_t len, bool forbid_nul) {
	/* Smallest code point that requires an N-byte encoding, indexed
	 * by N; a smaller value encoded in N bytes is overlong.  */
	static const uint32_t min_ch[5] = { 0, 0x0, 0x80, 0x800, 0x10000 };
	assert(str);
	for (size_t pos = 0; pos < len;) {
		uint8_t lead = str[pos];
		uint32_t ch;
		uint8_t chlen;
		if      ((lead & 0x80) == 0x00) { ch = lead;        chlen = 1; }
		else if ((lead & 0xE0) == 0xC0) { ch = lead & 0x1F; chlen = 2; }
		else if ((lead & 0xF0) == 0xE0) { ch = lead & 0x0F; chlen = 3; }
		else if ((lead & 0xF8) == 0xF0) { ch = lead & 0x07; chlen = 4; }
		else return false;
		/* pos < len here, so this subtraction cannot wrap (unlike
		 * `pos + chlen`, which could overflow near SIZE_MAX).  */
		if (chlen > len - pos)
			return false;
		for (uint8_t i = 1; i < chlen; i++) {
			if ((str[pos+i] & 0xC0) != 0x80)
				return false;
			ch = (ch << 6) | (str[pos+i] & 0x3F);
		}
		if (ch < min_ch[chlen])            /* overlong encoding */
			return false;
		if (ch >= 0xD800 && ch <= 0xDFFF)  /* UTF-16 surrogate */
			return false;
		if (ch > 0x10FFFF)                 /* beyond Unicode */
			return false;
		if (ch == 0 && forbid_nul)
			return false;
		pos += chlen;
	}
	return true;
}

/* Validate `len` bytes at `str` as UTF-8; embedded NUL bytes allowed.  */
#define is_valid_utf8(str, len)                 _is_valid_utf8(str, len, false)
/* Validate `len` bytes at `str` as UTF-8; any NUL byte is rejected.  */
#define is_valid_utf8_without_nul(str, len)     _is_valid_utf8(str, len, true)

/* marshal utilities **********************************************************/

/* Store the single byte `in` at `*out`.  */
static ALWAYS_INLINE void encode_u8le(uint8_t in, uint8_t *out) {
	*out = in;
}
/* Store `in` at `out[0..1]`, least-significant byte first; the
 * uint8_t cast keeps just the low 8 bits of each shifted value.  */
static ALWAYS_INLINE void encode_u16le(uint16_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < sizeof(in); byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}
/* Store `in` at `out[0..3]`, least-significant byte first; the
 * uint8_t cast keeps just the low 8 bits of each shifted value.  */
static ALWAYS_INLINE void encode_u32le(uint32_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < sizeof(in); byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}
/* Store `in` at `out[0..7]`, least-significant byte first; the
 * uint8_t cast keeps just the low 8 bits of each shifted value.  */
static ALWAYS_INLINE void encode_u64le(uint64_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < sizeof(in); byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}

#endif /* _LIB9P_INTERNAL_H_ */