summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-20 14:38:23 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-20 14:38:23 -0600
commit409242d2ee048fa6a5be4d893cce9a58d7fd73cb (patch)
tree756ef055cc2aa92259cf6b04d2b80eea8120f391
parent2caf7b05a1de8fcad42e33159008ca8f203f8caf (diff)
wip
-rwxr-xr-xnet9p_defs.gen324
1 files changed, 169 insertions, 155 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen
index 8e8fe27..9582923 100755
--- a/net9p_defs.gen
+++ b/net9p_defs.gen
@@ -4,6 +4,8 @@ import enum
import re
+# Parse net9p_defs.txt #########################################################
+
class Atom(enum.Enum):
u8 = 1
u16 = 2
@@ -103,6 +105,24 @@ def parse_file(filename: str) -> tuple[list[Struct], list[Message]]:
structs = [x for x in env.values() if isinstance(x, Struct)]
return structs, msgs
+# Generate C ###################################################################
+
+def shortname(typ: Atom | Struct | Message) -> str:
+ match typ:
+ case Atom.u8:
+ return "1"
+ case Atom.u16:
+ return "2"
+ case Atom.u32:
+ return "4"
+ case Atom.u64:
+ return "8"
+ case Struct():
+ return typ.name
+ case Message():
+ return 'msg_'+typ.name
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
def c_typename(typ: Atom | Struct | List | Message) -> str:
match typ:
@@ -124,6 +144,30 @@ def c_typename(typ: Atom | Struct | List | Message) -> str:
raise ValueError(f"not a type: {typ.__class__.__name__}")
+def static_size(typ: Atom | Struct | List | Message) -> int | None:
+ match typ:
+ case Atom.u8:
+ return 1
+ case Atom.u16:
+ return 2
+ case Atom.u32:
+ return 4
+ case Atom.u64:
+ return 8
+ case Struct() | Message():
+ size = 0
+ for member in typ.members:
+ msize = static_size(member.typ)
+ if msize is None:
+ return None
+ size += msize
+ return size
+ case List():
+ return None
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
def gen_h(structs: list[Struct], msgs: list[Message]) -> str:
ret = ""
ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n"
@@ -161,42 +205,57 @@ def gen_h(structs: list[Struct], msgs: list[Message]) -> str:
return ret
-c_atom_funcs = """
-static inline uint16_t unmarshal_u16le(uint8_t *bytes) {
- return (((uint16_t)(bytes[0])) << 0)
- | (((uint16_t)(bytes[1])) << 8)
- ;
+def gen_c(structs: list[Struct], msgs: list[Message]) -> str:
+ ret = """
+/* Generated by ./net9p_defs.gen. DO NOT EDIT! */
+
+#include <stdint.h> /* for size_t, uint{n}_t */
+#include <stdlib.h> /* for malloc() */
+
+#include "net9p_defs.h"
+"""
+
+ # basic utilities ##########################################################
+ ret += """
+/* basic utilities ************************************************************/
+
+#define UNUSED __attribute__ ((unused))
+
+static inline uint16_t decode_u16le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ ;
}
-static inline uint32_t unmarshal_u32le(uint8_t *bytes) {
- return (((uint16_t)(bytes[0])) << 0)
- | (((uint16_t)(bytes[1])) << 8)
- | (((uint16_t)(bytes[2])) << 16)
- | (((uint16_t)(bytes[3])) << 24)
- ;
+static inline uint32_t decode_u32le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ | (((uint16_t)(bytes[2])) << 16)
+ | (((uint16_t)(bytes[3])) << 24)
+ ;
}
-static inline uint64_t unmarshal_u64le(uint8_t *bytes) {
- return (((uint16_t)(bytes[0])) << 0)
- | (((uint16_t)(bytes[1])) << 8)
- | (((uint16_t)(bytes[2])) << 16)
- | (((uint16_t)(bytes[3])) << 24)
- | (((uint16_t)(bytes[4])) << 32)
- | (((uint16_t)(bytes[5])) << 40)
- | (((uint16_t)(bytes[6])) << 48)
- | (((uint16_t)(bytes[7])) << 56)
- ;
+static inline uint64_t decode_u64le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ | (((uint16_t)(bytes[2])) << 16)
+ | (((uint16_t)(bytes[3])) << 24)
+ | (((uint16_t)(bytes[4])) << 32)
+ | (((uint16_t)(bytes[5])) << 40)
+ | (((uint16_t)(bytes[6])) << 48)
+ | (((uint16_t)(bytes[7])) << 56)
+ ;
}
-static inline void marshal_u16le(val uint16_t, uint8_t *bytes) {
+static inline void decode_u16le(val uint16_t, uint8_t *bytes) {
bytes[0] = (uint8_t)((val >> 0) & 0xFF);
bytes[1] = (uint8_t)((val >> 8) & 0xFF);
}
-static inline void marshal_u32le(val uint32_t, uint8_t *bytes) {
+static inline void decode_u32le(val uint32_t, uint8_t *bytes) {
bytes[0] = (uint8_t)((val >> 0) & 0xFF);
bytes[1] = (uint8_t)((val >> 8) & 0xFF);
bytes[2] = (uint8_t)((val >> 16) & 0xFF);
bytes[3] = (uint8_t)((val >> 24) & 0xFF);
}
-static inline void marshal_u64le(val uint64_t, uint8_t *bytes) {
+static inline void decode_u64le(val uint64_t, uint8_t *bytes) {
bytes[0] = (uint8_t)((val >> 0) & 0xFF);
bytes[1] = (uint8_t)((val >> 8) & 0xFF);
bytes[2] = (uint8_t)((val >> 16) & 0xFF);
@@ -208,144 +267,99 @@ static inline void marshal_u64le(val uint64_t, uint8_t *bytes) {
}
"""
-
-def static_net_size(typ: Atom | Struct | List | Message) -> int | None:
- match typ:
- case Atom.u8:
- return 1
- case Atom.u16:
- return 2
- case Atom.u32:
- return 4
- case Atom.u64:
- return 8
- case Struct() | Message():
- size = 0
- for member in typ.members:
- msize = static_net_size(member.typ)
- if msize is None:
- return None
- size += msize
- return size
- case List():
- return None
- case _:
- raise ValueError(f"not a type: {typ.__class__.__name__}")
-
-
-def gen_c_check_net_len(msg: Message) -> str:
- ret: str = ""
- if (sz := static_net_size(msg)) is not None:
- ret += f"\tif (net_len != {sz})\n"
- ret += "\t\treturn -EINVAL;\n"
- return ret
-
- def has_big_list(typ: Atom | Struct | List | Message) -> bool:
- match typ:
- case Atom.u8:
- return False
- case Atom.u16:
- return False
- case Atom.u32:
- return False
- case Atom.u64:
- return False
- case Struct() | Message():
- return any(has_big_list(member.typ) for member in typ.members)
- case List():
- sz = static_net_size(typ.typ)
- if sz is None:
- return False
- return sz > 1
- case _:
- raise ValueError(f"not a type: {typ.__class__.__name__}")
-
- inited: bool = False
- static_acc: int = 0
- static_acc_name: list[str] = []
- def _gen_c_check_net_len(wsprefix: str, nameprefix: str, struct: Struct | Message) -> str:
- nonlocal inited
- nonlocal static_acc
- nonlocal static_acc_name
- ret: str = ""
-
- if has_big_list(struct):
- ret += f"{wsprefix}uint32_t sizeof_array;\n"
- prev_size: int = 0
+ # checksize ################################################################
+ ret += """
+/* checksize ******************************************************************/
+
+typedef bool (*_checksize_fn_t)(uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra);
+static inline bool _checksize_list(size_t cnt, checksize_fn_t fn, size_t host_size,
+ uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra) {
+ for (size_t i = 0; i < cnt; i++)
+ if (__builtin_add_overflow(*mut_host_extra, host_size, mut_host_extra)
+ || fn(net_len, net_bytes, mut_net_offset, mut_host_extra))
+ return true;
+ return false;
+}
+static inline bool checksize_1(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) {
+ return __builtin_add_overflow(*mut_net_offset, 1, mut_net_offset) || net_len < *mut_net_offset;
+}
+static inline bool checksize_2(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) {
+ return __builtin_add_overflow(*mut_net_offset, 2, mut_net_offset) || net_len < *mut_net_offset;
+}
+static inline bool checksize_4(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) {
+ return __builtin_add_overflow(*mut_net_offset, 4, mut_net_offset) || net_len < *mut_net_offset;
+}
+static inline bool checksize_8(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) {
+ return __builtin_add_overflow(*mut_net_offset, 8, mut_net_offset) || net_len < *mut_net_offset;
+}
+"""
+ for struct in structs + msgs:
+ argattr = ' UNUSED' if len(struct.members) == 0 else ''
+ ret += f"static inline bool checksize_{shortname(struct)}(uint32_t net_len{argattr}, uint8_t *net_bytes{argattr}, uint32_t *mut_net_offset{argattr}, size_t mut_host_extra{argattr}) {{"
+ if len(struct.members) == 0:
+ ret += "}\n"
+ continue
+ prefix0 = "\treturn "
+ prefix1 = "\t || "
+ prefix2 = "\t "
+ prefix = prefix0
+ prev_size = 0
for member in struct.members:
- if (sz := static_net_size(member.typ)) is not None:
- static_acc_name += [nameprefix+member.name]
- static_acc += sz
- prev_size = sz
- elif isinstance(member.typ, Struct):
- ret += _gen_c_check_net_len(wsprefix, nameprefix+member.name+'.', member.typ)
- elif isinstance(member.typ, List):
- if not inited:
- ret += f"{wsprefix}uint32_t calced_net_len = {static_acc};"
- if static_acc:
- ret += f" /* {', '.join(static_acc_name)} */"
- ret += "\n"
- static_acc_name = []
- static_acc = 0
- inited = True
- elif static_acc:
- ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n"
- static_acc_name = []
- static_acc = 0
- ret += f"{wsprefix}/* {nameprefix}{member.name} */\n"
- ret += f"{wsprefix}if (net_len < calced_net_len) return -EWRONGLEN;\n"
- if (sz := static_net_size(member.typ.typ)) is not None:
- if sz == 1:
- ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), &calced_net_len))\n"
- else:
- ret += f"{wsprefix}if (__builtin_mul_overflow(unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), {sz}, &sizeof_array)\n"
- ret += f"{wsprefix} || __builtin_add_overflow(calced_net_len, sizeof_array, &calced_net_len)\n"
- ret += f"{wsprefix}\treturn -EWRONGLEN;\n"
- else:
- assert isinstance(member.typ.typ, Struct)
- ret += f"{wsprefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]); i < cnt; i++) {{\n"
- ret += _gen_c_check_net_len(wsprefix + "\t", nameprefix+member.name+'[i].', member.typ.typ)
- if static_acc:
- ret += f"{wsprefix}\tif (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n"
- static_acc_name = []
- static_acc = 0
- ret += f"{wsprefix}\n"
+ if isinstance(member.typ, List):
+ ret += f"\n{prefix }_checksize_list(decode_u{prev_size*8}le(&net_bytes[(*mut_net_offset)-{prev_size}]), checksize_{shortname(member.typ.typ)}, sizeof({c_typename(member.typ.typ)}),"
+ ret += f"\n{prefix2} net_len, net_bytes, mut_net_offset, mut_host_extra)"
else:
- raise ValueError(f"not a type: {member.typ.__class__.__name__}")
-
- return ret
- ret += _gen_c_check_net_len('\t', '.', msg)
- if static_acc:
- ret += f"\tcalced_net_len += {static_acc}; /* {', '.join(static_acc_name)} */\n"
- ret += "\t/* final check */\n"
- ret += "\tif (net_len != calced_net_len)\n"
- ret += "\t\treturn -EINVAL;\n"
- return ret
-
-
-
-
-def gen_c(structs: list[Struct], msgs: list[Message]) -> str:
- ret = ""
- ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n"
- ret += "\n"
- ret += "#include <stdint.h> /* for size_t, uint{n}_t */\n"
- ret += "#include <stdlib.h> /* for malloc() */\n"
- ret += "\n"
- ret += c_atom_funcs
- ret += "\n"
-
- for msg in msgs:
- ret += f"int unmarshal_msg_{msg.name}(uint32_t net_len, uint8_t *net_bytes, {c_typename(msg)} **ret) {{\n"
- ret += gen_c_check_net_len(msg)
+ ret += f"\n{prefix}checksize_{shortname(member.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra)"
+ prefix = prefix1
+ if struct.name == 's':
+ ret += f"\n{prefix}__builtin_add_overflow(*mut_host_extra, 1, mut_host_extra)"
+ ret += ";\n}\n"
+
+ # unmarshal ################################################################
+ ret += """
+/* unmarshal ******************************************************************/
+/* checksize_XXX() should be called before unmarshal_XXX(). */
+
+static inline void unmarshal_1(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint8_t *out) {
+ *out = decode_u8le(&net_bytes[*mut_net_offset]);
+ *mut_net_offset += 1;
+}
+static inline void unmarshal_2(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint16_t *out) {
+ *out = decode_u16le(&net_bytes[*mut_net_offset]);
+ *mut_net_offset += 2;
+}
+static inline void unmarshal_4(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint32_t *out) {
+ *out = decode_u32le(&net_bytes[*mut_net_offset]);
+ *mut_net_offset += 4;
+}
+static inline void unmarshal_8(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint64_t *out) {
+ *out = decode_u64le(&net_bytes[*mut_net_offset]);
+ *mut_net_offset += 8;
+}
+"""
+ for struct in structs + msgs:
+ ret += f"static inline void unmarshal_{shortname(struct)}(uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, void *mut_host_extra, {c_typename(struct)} *out) {{"
+ if len(struct.members) == 0:
+ ret += "}\n"
+ continue
ret += "\n"
- ret += "\tTODO;\n"
+ for member in struct.members:
+ if isinstance(member.typ, List):
+ ret += f"\tout->{member.name} = mut_host_extra;\n"
+ ret += f"\t*mut_host_extra += sizeof(out->{member.name}) * out->{member.typ.cnt};\n"
+ ret += f"\tfor (typeof(out->{member.typ.cnt}) i = 0; i < out->{member.typ.cnt}; i++)\n"
+ ret += f"\t\tunmarshal_{shortname(member.typ.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}[i]));\n"
+ else:
+ ret += f"\tunmarshal_{shortname(member.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}));\n"
ret += "}\n"
- ret += "\n"
+
+ ############################################################################
return ret
+################################################################################
+
if __name__ == "__main__":
structs, msgs = parse_file("net9p_defs.txt")
- print(gen_h(structs, msgs))
+ #print(gen_h(structs, msgs))
print(gen_c(structs, msgs))