summaryrefslogtreecommitdiff
path: root/lib9p/internal.h
blob: cbec829426c62f7b53e72a2761b93ebdf3f34034 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
/* lib9p/internal.h - Internal (non-public) declarations and helpers for lib9p
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#ifndef _LIB9P_INTERNAL_H_
#define _LIB9P_INTERNAL_H_

#include <assert.h>
#include <stddef.h> /* for size_t */
#include <limits.h> /* for SSIZE_MAX */

#include <lib9p/9p.h>

#define USE_CONFIG_9P
#define USE_CONFIG_COROUTINE
#include "config.h"
/* Compile-time sanity checks on the configured limits.  Each assertion
 * carries an explicit message because the single-argument form of
 * static_assert is only standard as of C23; C11/C17 require a message.  */

/* err_msg (CONFIG_9P_MAX_ERR_SIZE bytes) must have a length that fits in
 * 16 bits -- presumably because it is sent as a u16-length-prefixed 9P
 * string; TODO confirm.  */
static_assert(CONFIG_9P_MAX_ERR_SIZE <= UINT16_MAX,
              "CONFIG_9P_MAX_ERR_SIZE must be <= UINT16_MAX");
static_assert(CONFIG_9P_MAX_MSG_SIZE <= CONFIG_9P_MAX_HOSTMSG_SIZE,
              "CONFIG_9P_MAX_MSG_SIZE must be <= CONFIG_9P_MAX_HOSTMSG_SIZE");
/* Host-side message sizes must be representable as a positive ssize_t.  */
static_assert(CONFIG_9P_MAX_HOSTMSG_SIZE <= SSIZE_MAX,
              "CONFIG_9P_MAX_HOSTMSG_SIZE must be <= SSIZE_MAX");
/* One error buffer, one wire message, and two host messages must together
 * be smaller than a default coroutine stack -- NOTE(review): presumably
 * because these buffers live on that stack; confirm against callers.  */
static_assert(CONFIG_9P_MAX_ERR_SIZE + CONFIG_9P_MAX_MSG_SIZE + 2*CONFIG_9P_MAX_HOSTMSG_SIZE < CONFIG_COROUTINE_DEFAULT_STACK_SIZE,
              "9P buffer sizes must be smaller than CONFIG_COROUTINE_DEFAULT_STACK_SIZE");

/* C language *****************************************************************/

/* UNUSED(name): mark a parameter as intentionally unused by dropping its
 * name entirely, so `void f(int UNUSED(x))` becomes `void f(int )`.
 * Unnamed parameters in function definitions are standard as of C23;
 * the attribute spelling `name __attribute__((unused))` is the
 * alternative for older language modes.  */
#define UNUSED(name)

/* GCC/Clang function attributes: force a helper to be inlined, or inline
 * everything a function calls into its body.  */
#define ALWAYS_INLINE           inline __attribute__((always_inline))
#define FLATTEN                 __attribute__((flatten))

/* Element count of a true array object.  This does not work on pointers
 * or on array-typed function parameters, which have decayed to pointers.  */
#define ARRAY_LEN(arr)          (sizeof(arr) / sizeof(*(arr)))

/* Token pasting.  Arguments are pasted as written and are NOT
 * macro-expanded first; add an extra macro layer if expansion is needed.  */
#define CAT2(a, b)              a##b
#define CAT3(a, b, c)           a##b##c

/* types **********************************************************************/

/* Context shared by the validate/unmarshal/marshal machinery; each of the
 * _*_ctx structs in this header holds a `struct lib9p_ctx *`.
 *
 * NB: This is declared here rather than in the public <lib9p/9p.h>
 * because public headers must not include "config.h", and the size of
 * `err_msg` (CONFIG_9P_MAX_ERR_SIZE) is meant to be configurable.  */
struct lib9p_ctx {
	/* negotiated */
	/* NOTE(review): presumably settled during version negotiation
	 * (Tversion/Rversion); confirm where these are written.  */
	enum lib9p_version      version;
	uint32_t                max_msg_size;

	/* state */
	/* NOTE(review): presumably the pending error number/message to report
	 * to the peer.  err_msg is at most CONFIG_9P_MAX_ERR_SIZE bytes, which is
	 * statically asserted to be <= UINT16_MAX.  */
	uint32_t                err_num;
	char                    err_msg[CONFIG_9P_MAX_ERR_SIZE];
};

/* vtables ********************************************************************/

/* Cursor state for a message's `validate` pass over wire bytes.  The
 * validator walks `net_bytes` and totals the host-side memory that a later
 * unmarshal will need beyond the fixed-size struct.  */
struct _validate_ctx {
	struct lib9p_ctx        *ctx;
	/* Wire bytes being checked; net_size is presumably their total length
	 * (TODO confirm).  */
	uint32_t                 net_size;
	uint8_t                 *net_bytes;

	/* Current read position within net_bytes.  */
	uint32_t                 net_offset;
	/* Increment `host_extra` to pre-allocate space that is
	 * "extra" beyond sizeof().  */
	size_t	                 host_extra;
};
/* NOTE(review): whether `true` means "valid" or "error" is not established
 * in this header; check the implementations before relying on it.  */
typedef bool (*_validate_fn_t)(struct _validate_ctx *ctx);

/* Cursor state for decoding wire bytes into a host-side struct.  */
struct _unmarshal_ctx {
	struct lib9p_ctx        *ctx;
	/* Wire bytes being decoded; presumably already checked by the
	 * matching validate pass (TODO confirm).  */
	uint8_t                 *net_bytes;

	/* Current read position within net_bytes.  */
	uint32_t                 net_offset;
	/* `extra` points to the beginning of unallocated space.  */
	void                    *extra;
};
/* Decode from ctx->net_bytes into the host struct pointed to by `out`.  */
typedef void (*_unmarshal_fn_t)(struct _unmarshal_ctx *ctx, void *out);

/* Cursor state for encoding a host-side struct into wire bytes.  */
struct _marshal_ctx {
	struct lib9p_ctx        *ctx;

	/* Output buffer and current write position within it.  */
	uint8_t                 *net_bytes;
	uint32_t                 net_offset;
};
/* Encode the host struct pointed to by `host_val` into ctx->net_bytes.
 * NOTE(review): the meaning of the bool result (success vs. error, e.g.
 * exceeding max_msg_size) is not established here; confirm in callers.  */
typedef bool (*_marshal_fn_t)(struct _marshal_ctx *ctx, void *host_val);

/* Codec entry points for one message type.  */
struct _vtable_msg {
	/* NOTE(review): presumably sizeof() the host-side struct for this
	 * message, to which validate's host_extra is added; confirm.  */
	size_t                  basesize;
	_validate_fn_t          validate;
	_unmarshal_fn_t         unmarshal;
	_marshal_fn_t           marshal;
};

/* Codec table for one protocol version, presumably indexed by the 1-byte 9P
 * message type.  It has 0x100 entries so that every possible uint8_t index,
 * including 0xFF, is in bounds.  The previous 0xFF-entry size made type 0xFF
 * read one element past the end.  Entries for unused types are presumably
 * left zeroed by the definition (TODO confirm in the generated tables).  */
struct _vtable_version {
	struct _vtable_msg      msgs[0x100];
};

/* One table per supported protocol version, indexed by enum lib9p_version.  */
extern struct _vtable_version _lib9p_vtables[LIB9P_VER_NUM];

/* unmarshal utilities ********************************************************/

/* Read one byte from in[0].  This exists for symmetry with the wider
 * decode_uNle() helpers.  `in` is only read, hence const.  */
static ALWAYS_INLINE uint8_t decode_u8le(const uint8_t *in) {
	return in[0];
}
/* Decode a little-endian 16-bit integer from in[0..1].  Bytes are assembled
 * individually, so the result does not depend on host endianness and `in`
 * need not be aligned.  The shifts operate on int (integer promotion), so
 * the result is narrowed back explicitly.  */
static ALWAYS_INLINE uint16_t decode_u16le(const uint8_t *in) {
	return (uint16_t)( (((uint16_t)(in[0])) << 0)
	                 | (((uint16_t)(in[1])) << 8)
	                 );
}
/* Decode a little-endian 32-bit integer from in[0..3], independent of host
 * endianness and alignment.  `in` is only read, hence const.  */
static ALWAYS_INLINE uint32_t decode_u32le(const uint8_t *in) {
	return (((uint32_t)(in[0])) <<  0)
	     | (((uint32_t)(in[1])) <<  8)
	     | (((uint32_t)(in[2])) << 16)
	     | (((uint32_t)(in[3])) << 24)
	     ;
}
/* Decode a little-endian 64-bit integer from in[0..7], independent of host
 * endianness and alignment.  Each byte is widened to uint64_t before
 * shifting so that shifts of 32 and above are well-defined.  */
static ALWAYS_INLINE uint64_t decode_u64le(const uint8_t *in) {
	return (((uint64_t)(in[0])) <<  0)
	     | (((uint64_t)(in[1])) <<  8)
	     | (((uint64_t)(in[2])) << 16)
	     | (((uint64_t)(in[3])) << 24)
	     | (((uint64_t)(in[4])) << 32)
	     | (((uint64_t)(in[5])) << 40)
	     | (((uint64_t)(in[6])) << 48)
	     | (((uint64_t)(in[7])) << 56)
	     ;
}

/* Report whether the `len` bytes at `str` are well-formed UTF-8 per
 * RFC 3629:
 *  - every sequence has a valid lead byte and the right number of
 *    continuation bytes (10xxxxxx), with none cut off by `len`;
 *  - every code point uses its shortest encoding (no overlong forms);
 *  - no code point exceeds U+10FFFF or is a UTF-16 surrogate
 *    (U+D800..U+DFFF).
 * If `forbid_nul` is true, U+0000 (a literal 0x00 byte) is rejected too.
 *
 * Exactly `len` bytes are examined; `str` need not be NUL-terminated.
 * An empty input (len == 0) is valid.  */
static inline bool _is_valid_utf8(const uint8_t *str, size_t len, bool forbid_nul) {
	assert(str);
	for (size_t pos = 0; pos < len;) {
		uint8_t  lead = str[pos];
		uint32_t ch;
		uint8_t  chlen;
		uint32_t min; /* smallest code point that actually needs `chlen` bytes */
		if      ((lead & 0x80) == 0x00) { ch = lead & 0x7F; chlen = 1; min = 0x0;     }
		else if ((lead & 0xE0) == 0xC0) { ch = lead & 0x1F; chlen = 2; min = 0x80;    }
		else if ((lead & 0xF0) == 0xE0) { ch = lead & 0x0F; chlen = 3; min = 0x800;   }
		else if ((lead & 0xF8) == 0xF0) { ch = lead & 0x07; chlen = 4; min = 0x10000; }
		else return false; /* stray continuation byte, or 0xF8..0xFF */
		/* pos < len here, so len - pos cannot underflow.  */
		if (chlen > len - pos) return false; /* truncated sequence */
		for (uint8_t i = 1; i < chlen; i++) {
			if ((str[pos+i] & 0xC0) != 0x80) return false;
			ch = (ch << 6) | (str[pos+i] & 0x3F);
		}
		/* Judge overlong-ness on the fully decoded value.  Checking only the
		 * lead byte's payload would wrongly reject valid U+0800..U+0FFF (lead
		 * 0xE0) and U+10000..U+3FFFF (lead 0xF0), and miss overlong forms
		 * such as C1 BF.  */
		if (ch < min) return false;
		if (ch > 0x10FFFF) return false;
		if (ch >= 0xD800 && ch <= 0xDFFF) return false; /* surrogate */
		if (ch == 0 && forbid_nul) return false;
		pos += chlen;
	}
	return true;
}

#define is_valid_utf8(str, len)                 _is_valid_utf8(str, len, false)
#define is_valid_utf8_without_nul(str, len)     _is_valid_utf8(str, len, true)

/* marshal utilities **********************************************************/

/* Store the single byte `in` at *out.  This is the counterpart of
 * decode_u8le() and exists for symmetry with the wider encoders.  */
static ALWAYS_INLINE void encode_u8le(uint8_t in, uint8_t *out) {
	*out = in;
}
/* Write `in` to out[0..1], least-significant byte first.  Byte k holds bits
 * 8k..8k+7; the uint8_t cast keeps only those low 8 bits.  */
static ALWAYS_INLINE void encode_u16le(uint16_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < 2; byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}
/* Write `in` to out[0..3], least-significant byte first.  This is
 * independent of host endianness, and `out` need not be aligned.  */
static ALWAYS_INLINE void encode_u32le(uint32_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < 4; byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}
/* Write `in` to out[0..7], least-significant byte first.  Shifts go up to 56
 * bits, which is well-defined because `in` is a 64-bit unsigned value.  */
static ALWAYS_INLINE void encode_u64le(uint64_t in, uint8_t *out) {
	for (unsigned byte = 0; byte < 8; byte++)
		out[byte] = (uint8_t)(in >> (8 * byte));
}

#endif /* _LIB9P_INTERNAL_H_ */