diff options
Diffstat (limited to 'net9p_defs.gen')
-rwxr-xr-x | net9p_defs.gen | 309 |
1 files changed, 309 insertions, 0 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen new file mode 100755 index 0000000..0e75e42 --- /dev/null +++ b/net9p_defs.gen @@ -0,0 +1,309 @@ +#!/usr/bin/env python + +import enum +import re + + +class Atom(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 3 + u64 = 4 + + +class Struct: + name: str + members: list["Member"] + + +class List: + cnt: str + typ: Atom | Struct + + def __init__(self, /, *, cnt: str, typ: Atom | Struct) -> None: + self.cnt = cnt + self.typ = typ + + +class Member: + name: str + typ: Atom | Struct | List + + def __init__(self, /, *, name: str, typ: Atom | Struct | List) -> None: + self.name = name + self.typ = typ + + +def parse_members( + env: dict[str, Atom | Struct], existing: list[Member], specs: str +) -> list[Member]: + ret = existing + for spec in specs.split(): + m = re.fullmatch(r"(.+)\[([^*]+)(?:\*([^*]+))?\]", spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + if m.group(2) not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + name = m.group(1) + typ = env[m.group(2)] + if any(x.name == name for x in ret): + raise ValueError(f"duplicate member name {repr(name)}") + if cnt := m.group(3): + if len(ret) == 0 or ret[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(ret[-1].typ, Atom): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + ret += [Member(name=name, typ=List(cnt=cnt, typ=typ))] + else: + ret += [Member(name=name, typ=typ)] + return ret + + +class Message: + id: int + name: str + members: list[Member] + + +def parse_file(filename: str) -> tuple[list[Struct], list[Message]]: + msgs: list[Message] = [] + env: dict[str, Atom | Struct] = { + "1": Atom.u8, + "2": Atom.u16, + "4": Atom.u32, + "8": Atom.u64, + } + with open(filename, "r") as fh: + prev: Struct | Message | None = None + for line in fh: + line = line.split("#", 1)[0].strip() + if not line: + continue + if m := re.fullmatch(r'([0-9]+)\s*=\s*(\S+)\s*"([^"]*)"', line): + msg = Message() + msg.id = int(m.group(1)) + msg.name = m.group(2) + msg.members = parse_members(env, [], m.group(3)) + msgs += [msg] + prev = msg + elif m := re.fullmatch(r'(\S+)\s*=\s*"([^"]*)"', line): + struct = Struct() + struct.name = m.group(1) + struct.members = parse_members(env, [], m.group(2)) + env[struct.name] = struct + prev = struct + elif m := re.fullmatch(r'"([^"]*)"', line): + if not prev: + raise SyntaxError( + "a continuation line must come after a struct line" + ) + prev.members = parse_members(env, prev.members, line.strip('"')) + else: + raise SyntaxError(f"invalid line {repr(line)}") + structs = [x for x in env.values() if isinstance(x, Struct)] + return structs, msgs + + +def c_typename(typ: Atom | Struct | List | Message) -> str: + match typ: + case Atom.u8: + return "uint8_t" + case Atom.u16: + return "uint16_t" + case Atom.u32: + return "uint32_t" + case Atom.u64: + return "uint64_t" + case Struct(): + return "struct v9fs_" + typ.name + case Message(): + return "struct v9fs_msg_" + typ.name + case List(): + return c_typename(typ.typ) + "*" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def gen_h(structs: list[Struct], msgs: list[Message]) -> str: + ret = "" + ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" + ret += "\n" + ret += "#ifndef _NET9P_DEFS_H_\n" + ret += "#define _NET9P_DEFS_H_\n" + ret += "\n" + + for struct in structs: + ret += c_typename(struct) + " {\n" + typewidth = max(len(c_typename(member.typ)) for member in struct.members) + for member in struct.members: + ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n" + ret += "};\n" + ret += "\n" + + ret += "enum v9fs_msg_type {\n" + namewidth = max(len(msg.name) for msg in msgs) + for msg in msgs: + ret += f"\tV9FS_TYP_{msg.name.ljust(namewidth)} = {msg.id},\n" + ret += "};\n" + + for msg in msgs: + if not msg.members: + ret += c_typename(msg) + " {};\n" + else: + ret += c_typename(msg) + " {\n" + typewidth = max(len(c_typename(member.typ)) for member in msg.members) + for member in msg.members: + ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n" + ret += "};\n" + ret += "\n" + + ret += "#endif /* _NET9P_DEFS_H_ */\n" + return ret + + +c_atom_funcs = """ +static inline uint16_t unmarshal_u16le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + ; +} +static inline uint32_t unmarshal_u32le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + | (((uint16_t)(bytes[2])) << 16) + | (((uint16_t)(bytes[3])) << 24) + ; +} +static inline uint64_t unmarshal_u64le(uint8_t *bytes) { + return (((uint16_t)(bytes[0])) << 0) + | (((uint16_t)(bytes[1])) << 8) + | (((uint16_t)(bytes[2])) << 16) + | (((uint16_t)(bytes[3])) << 24) + | (((uint16_t)(bytes[4])) << 32) + | (((uint16_t)(bytes[5])) << 40) + | (((uint16_t)(bytes[6])) << 48) + | (((uint16_t)(bytes[7])) << 56) + ; +} + +static inline void marshal_u16le(val uint16_t, uint8_t *bytes) { + bytes[0] = (uint8_t)((val >> 0) & 0xFF); + bytes[1] = (uint8_t)((val >> 8) & 0xFF); +} +static inline void marshal_u32le(val uint32_t, uint8_t *bytes) { + bytes[0] = (uint8_t)((val >> 0) & 0xFF); + bytes[1] = (uint8_t)((val >> 8) & 0xFF); + bytes[2] = (uint8_t)((val >> 16) & 0xFF); + bytes[3] = (uint8_t)((val >> 24) & 0xFF); +} +static inline void marshal_u64le(val uint64_t, uint8_t *bytes) { + bytes[0] = (uint8_t)((val >> 0) & 0xFF); + bytes[1] = (uint8_t)((val >> 8) & 0xFF); + bytes[2] = (uint8_t)((val >> 16) & 0xFF); + bytes[3] = (uint8_t)((val >> 24) & 0xFF); + bytes[4] = (uint8_t)((val >> 32) & 0xFF); + bytes[5] = (uint8_t)((val >> 40) & 0xFF); + bytes[6] = (uint8_t)((val >> 48) & 0xFF); + bytes[7] = (uint8_t)((val >> 56) & 0xFF); +} +""" + + +def static_net_size(typ: Atom | Struct | List | Message) -> int | None: + match typ: + case Atom.u8: + return 1 + case Atom.u16: + return 2 + case Atom.u32: + return 4 + case Atom.u64: + return 8 + case Struct() | Message(): + size = 0 + for member in typ.members: + msize = static_net_size(member.typ) + if msize is None: + return None + size += msize + return size + case List(): + return None + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def gen_c_check_net_len(msg: Message) -> str: + ret: str = "" + if (sz := static_net_size(msg)) is not None: + ret += f"\tif (net_len != {sz})\n" + ret += "\t\treturn -EINVAL;\n" + return ret + + ret += f"\tuint64_t net_offset = 0;\n" + static_acc = 0 + def _gen_c_check_net_len(prefix: str, struct: Struct | Message) -> str: + nonlocal static_acc + ret: str = "" + + prev_size: int = 0 + for member in struct.members: + if (sz := static_net_size(member.typ)) is not None: + static_acc += sz + prev_size = sz + elif isinstance(member.typ, Struct): + ret += _gen_c_check_net_len(prefix, member.typ) + elif isinstance(member.typ, List): + if static_acc: + ret += f"{prefix}net_offset += {static_acc};\n" + static_acc = 0 + ret += f"{prefix}if ((uint64_t)net_len < net_offset)\n" + ret += f"{prefix}\treturn -EINVAL;\n" + if (sz := static_net_size(member.typ.typ)) is not None: + ret += f"{prefix}net_offset += unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}])*{sz};\n" + else: + assert isinstance(member.typ.typ, Struct) + ret += f"{prefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}]); i < cnt; i++) {{\n" + ret += _gen_c_check_net_len(prefix + "\t", member.typ.typ) + if static_acc: + ret += f"{prefix}net_offset += {static_acc};\n" + static_acc = 0 + ret += f"{prefix}\n" + else: + raise ValueError(f"not a type: {member.typ.__class__.__name__}") + + return ret + ret += _gen_c_check_net_len("\t", msg) + if static_acc: + ret += f"\tnet_offset += {static_acc};\n" + ret += "\tif ((uint64_t)net_len != net_offset)\n" + ret += "\t\treturn -EINVAL;\n" + return ret + + + + +def gen_c(structs: list[Struct], msgs: list[Message]) -> str: + ret = "" + ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" + ret += "\n" + ret += "#include <stdint.h> /* for size_t, uint{n}_t */\n" + ret += "#include <stdlib.h> /* for malloc() */\n" + ret += "\n" + ret += c_atom_funcs + ret += "\n" + + for msg in msgs: + ret += f"int unmarshal_msg_{msg.name}(uint32_t net_len, uint8_t *net_bytes, {c_typename(msg)} **ret) {{\n" + ret += gen_c_check_net_len(msg) + ret += "\n" + ret += "\tTODO;\n" + ret += "}\n" + ret += "\n" + return ret + + +if __name__ == "__main__": + structs, msgs = parse_file("net9p_defs.txt") + print(gen_h(structs, msgs)) + print(gen_c(structs, msgs)) |