diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-20 14:38:23 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-20 14:38:23 -0600 |
commit | 409242d2ee048fa6a5be4d893cce9a58d7fd73cb (patch) | |
tree | 756ef055cc2aa92259cf6b04d2b80eea8120f391 | |
parent | 2caf7b05a1de8fcad42e33159008ca8f203f8caf (diff) |
wip
-rwxr-xr-x | net9p_defs.gen | 324 |
1 files changed, 169 insertions, 155 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen index 8e8fe27..9582923 100755 --- a/net9p_defs.gen +++ b/net9p_defs.gen @@ -4,6 +4,8 @@ import enum import re +# Parse net9p_defs.txt ######################################################### + class Atom(enum.Enum): u8 = 1 u16 = 2 @@ -103,6 +105,24 @@ def parse_file(filename: str) -> tuple[list[Struct], list[Message]]: structs = [x for x in env.values() if isinstance(x, Struct)] return structs, msgs +# Generate C ################################################################### + +def shortname(typ: Atom | Struct | Message) -> str: + match typ: + case Atom.u8: + return "1" + case Atom.u16: + return "2" + case Atom.u32: + return "4" + case Atom.u64: + return "8" + case Struct(): + return typ.name + case Message(): + return 'msg_'+typ.name + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") def c_typename(typ: Atom | Struct | List | Message) -> str: match typ: @@ -124,6 +144,30 @@ def c_typename(typ: Atom | Struct | List | Message) -> str: raise ValueError(f"not a type: {typ.__class__.__name__}") +def static_size(typ: Atom | Struct | List | Message) -> int | None: + match typ: + case Atom.u8: + return 1 + case Atom.u16: + return 2 + case Atom.u32: + return 4 + case Atom.u64: + return 8 + case Struct() | Message(): + size = 0 + for member in typ.members: + msize = static_size(member.typ) + if msize is None: + return None + size += msize + return size + case List(): + return None + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + def gen_h(structs: list[Struct], msgs: list[Message]) -> str: ret = "" ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" @@ -161,42 +205,57 @@ def gen_h(structs: list[Struct], msgs: list[Message]) -> str: return ret -c_atom_funcs = """ -static inline uint16_t unmarshal_u16le(uint8_t *bytes) { - return (((uint16_t)(bytes[0])) << 0) - | (((uint16_t)(bytes[1])) << 8) - ; +def gen_c(structs: list[Struct], msgs: list[Message]) -> str: + ret = """ +/* Generated by ./net9p_defs.gen. DO NOT EDIT! */ + +#include <stdint.h> /* for size_t, uint{n}_t */ +#include <stdlib.h> /* for malloc() */ + +#include "net9p_defs.h" +""" + + # basic utilities ########################################################## + ret += """ +/* basic utilities ************************************************************/ + +#define UNUSED __attribute__ ((unused)) + +static inline uint16_t decode_u16le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + ; } -static inline uint32_t unmarshal_u32le(uint8_t *bytes) { - return (((uint16_t)(bytes[0])) << 0) - | (((uint16_t)(bytes[1])) << 8) - | (((uint16_t)(bytes[2])) << 16) - | (((uint16_t)(bytes[3])) << 24) - ; +static inline uint32_t decode_u32le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + | (((uint16_t)(bytes[2])) << 16) + | (((uint16_t)(bytes[3])) << 24) + ; } -static inline uint64_t unmarshal_u64le(uint8_t *bytes) { - return (((uint16_t)(bytes[0])) << 0) - | (((uint16_t)(bytes[1])) << 8) - | (((uint16_t)(bytes[2])) << 16) - | (((uint16_t)(bytes[3])) << 24) - | (((uint16_t)(bytes[4])) << 32) - | (((uint16_t)(bytes[5])) << 40) - | (((uint16_t)(bytes[6])) << 48) - | (((uint16_t)(bytes[7])) << 56) - ; +static inline uint64_t decode_u64le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + | (((uint16_t)(bytes[2])) << 16) + | (((uint16_t)(bytes[3])) << 24) + | (((uint16_t)(bytes[4])) << 32) + | (((uint16_t)(bytes[5])) << 40) + | (((uint16_t)(bytes[6])) << 48) + | (((uint16_t)(bytes[7])) << 56) + ; } -static inline void marshal_u16le(val uint16_t, uint8_t *bytes) { +static inline void decode_u16le(val uint16_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); } -static inline void marshal_u32le(val uint32_t, uint8_t *bytes) { +static inline void decode_u32le(val uint32_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); bytes[2] = (uint8_t)((val >> 16) & 0xFF); bytes[3] = (uint8_t)((val >> 24) & 0xFF); } -static inline void marshal_u64le(val uint64_t, uint8_t *bytes) { +static inline void decode_u64le(val uint64_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); bytes[2] = (uint8_t)((val >> 16) & 0xFF); @@ -208,144 +267,99 @@ static inline void marshal_u64le(val uint64_t, uint8_t *bytes) { } """ - -def static_net_size(typ: Atom | Struct | List | Message) -> int | None: - match typ: - case Atom.u8: - return 1 - case Atom.u16: - return 2 - case Atom.u32: - return 4 - case Atom.u64: - return 8 - case Struct() | Message(): - size = 0 - for member in typ.members: - msize = static_net_size(member.typ) - if msize is None: - return None - size += msize - return size - case List(): - return None - case _: - raise ValueError(f"not a type: {typ.__class__.__name__}") - - -def gen_c_check_net_len(msg: Message) -> str: - ret: str = "" - if (sz := static_net_size(msg)) is not None: - ret += f"\tif (net_len != {sz})\n" - ret += "\t\treturn -EINVAL;\n" - return ret - - def has_big_list(typ: Atom | Struct | List | Message) -> bool: - match typ: - case Atom.u8: - return False - case Atom.u16: - return False - case Atom.u32: - return False - case Atom.u64: - return False - case Struct() | Message(): - return any(has_big_list(member.typ) for member in typ.members) - case List(): - sz = static_net_size(typ.typ) - if sz is None: - return False - return sz > 1 - case _: - raise ValueError(f"not a type: {typ.__class__.__name__}") - - inited: bool = False - static_acc: int = 0 - static_acc_name: list[str] = [] - def _gen_c_check_net_len(wsprefix: str, nameprefix: str, struct: Struct | Message) -> str: - nonlocal inited - nonlocal static_acc - nonlocal static_acc_name - ret: str = "" - - if has_big_list(struct): - ret += f"{wsprefix}uint32_t sizeof_array;\n" - prev_size: int = 0 + # checksize ################################################################ + ret += """ +/* checksize ******************************************************************/ + +typedef bool (*_checksize_fn_t)(uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra); +static inline bool _checksize_list(size_t cnt, checksize_fn_t fn, size_t host_size, + uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra) { + for (size_t i = 0; i < cnt; i++) + if (__builtin_add_overflow(*mut_host_extra, host_size, mut_host_extra) + || fn(net_len, net_bytes, mut_net_offset, mut_host_extra)) + return true; + return false; +} +static inline bool checksize_1(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) { + return __builtin_add_overflow(*mut_net_offset, 1, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_2(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) { + return __builtin_add_overflow(*mut_net_offset, 2, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_4(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) { + return __builtin_add_overflow(*mut_net_offset, 4, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_8(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, size_t *mut_host_extra UNUSED) { + return __builtin_add_overflow(*mut_net_offset, 8, mut_net_offset) || net_len < *mut_net_offset; +} +""" + for struct in structs + msgs: + argattr = ' UNUSED' if len(struct.members) == 0 else '' + ret += f"static inline bool checksize_{shortname(struct)}(uint32_t net_len{argattr}, uint8_t *net_bytes{argattr}, uint32_t *mut_net_offset{argattr}, size_t mut_host_extra{argattr}) {{" + if len(struct.members) == 0: + ret += "}\n" + continue + prefix0 = "\treturn " + prefix1 = "\t || " + prefix2 = "\t " + prefix = prefix0 + prev_size = 0 for member in struct.members: - if (sz := static_net_size(member.typ)) is not None: - static_acc_name += [nameprefix+member.name] - static_acc += sz - prev_size = sz - elif isinstance(member.typ, Struct): - ret += _gen_c_check_net_len(wsprefix, nameprefix+member.name+'.', member.typ) - elif isinstance(member.typ, List): - if not inited: - ret += f"{wsprefix}uint32_t calced_net_len = {static_acc};" - if static_acc: - ret += f" /* {', '.join(static_acc_name)} */" - ret += "\n" - static_acc_name = [] - static_acc = 0 - inited = True - elif static_acc: - ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n" - static_acc_name = [] - static_acc = 0 - ret += f"{wsprefix}/* {nameprefix}{member.name} */\n" - ret += f"{wsprefix}if (net_len < calced_net_len) return -EWRONGLEN;\n" - if (sz := static_net_size(member.typ.typ)) is not None: - if sz == 1: - ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), &calced_net_len))\n" - else: - ret += f"{wsprefix}if (__builtin_mul_overflow(unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), {sz}, &sizeof_array)\n" - ret += f"{wsprefix} || __builtin_add_overflow(calced_net_len, sizeof_array, &calced_net_len)\n" - ret += f"{wsprefix}\treturn -EWRONGLEN;\n" - else: - assert isinstance(member.typ.typ, Struct) - ret += f"{wsprefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]); i < cnt; i++) {{\n" - ret += _gen_c_check_net_len(wsprefix + "\t", nameprefix+member.name+'[i].', member.typ.typ) - if static_acc: - ret += f"{wsprefix}\tif (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n" - static_acc_name = [] - static_acc = 0 - ret += f"{wsprefix}\n" + if isinstance(member.typ, List): + ret += f"\n{prefix }_checksize_list(decode_u{prev_size*8}le(&net_bytes[(*mut_net_offset)-{prev_size}]), checksize_{shortname(member.typ.typ)}, sizeof({c_typename(member.typ.typ)})," + ret += f"\n{prefix2} net_len, net_bytes, mut_net_offset, mut_host_extra)" else: - raise ValueError(f"not a type: {member.typ.__class__.__name__}") - - return ret - ret += _gen_c_check_net_len('\t', '.', msg) - if static_acc: - ret += f"\tcalced_net_len += {static_acc}; /* {', '.join(static_acc_name)} */\n" - ret += "\t/* final check */\n" - ret += "\tif (net_len != calced_net_len)\n" - ret += "\t\treturn -EINVAL;\n" - return ret - - - - -def gen_c(structs: list[Struct], msgs: list[Message]) -> str: - ret = "" - ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" - ret += "\n" - ret += "#include <stdint.h> /* for size_t, uint{n}_t */\n" - ret += "#include <stdlib.h> /* for malloc() */\n" - ret += "\n" - ret += c_atom_funcs - ret += "\n" - - for msg in msgs: - ret += f"int unmarshal_msg_{msg.name}(uint32_t net_len, uint8_t *net_bytes, {c_typename(msg)} **ret) {{\n" - ret += gen_c_check_net_len(msg) + ret += f"\n{prefix}checksize_{shortname(member.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra)" + prefix = prefix1 + if struct.name == 's': + ret += f"\n{prefix}__builtin_add_overflow(*mut_host_extra, 1, mut_host_extra)" + ret += ";\n}\n" + + # unmarshal ################################################################ + ret += """ +/* unmarshal ******************************************************************/ +/* checksize_XXX() should be called before unmarshal_XXX(). */ + +static inline void unmarshal_1(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint8_t *out) { + *out = decode_u8le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 1; +} +static inline void unmarshal_2(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint16_t *out) { + *out = decode_u16le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 2; +} +static inline void unmarshal_4(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint32_t *out) { + *out = decode_u32le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 4; +} +static inline void unmarshal_8(uint32_t net_len, uint8_t *net_bytes UNUSED, uint32_t *mut_net_offset, void *mut_host_extra UNUSED, uint64_t *out) { + *out = decode_u64le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 8; +} +""" + for struct in structs + msgs: + ret += f"static inline void unmarshal_{shortname(struct)}(uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, void *mut_host_extra, {c_typename(struct)} *out) {{" + if len(struct.members) == 0: + ret += "}\n" + continue ret += "\n" - ret += "\tTODO;\n" + for member in struct.members: + if isinstance(member.typ, List): + ret += f"\tout->{member.name} = mut_host_extra;\n" + ret += f"\t*mut_host_extra += sizeof(out->{member.name}) * out->{member.typ.cnt};\n" + ret += f"\tfor (typeof(out->{member.typ.cnt}) i = 0; i < out->{member.typ.cnt}; i++)\n" + ret += f"\t\tunmarshal_{shortname(member.typ.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}[i]));\n" + else: + ret += f"\tunmarshal_{shortname(member.typ)}(net_len, net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}));\n" ret += "}\n" - ret += "\n" + + ############################################################################ return ret +################################################################################ + if __name__ == "__main__": structs, msgs = parse_file("net9p_defs.txt") - print(gen_h(structs, msgs)) + #print(gen_h(structs, msgs)) print(gen_c(structs, msgs)) |