#!/usr/bin/env python import enum import re class Atom(enum.Enum): u8 = 1 u16 = 2 u32 = 3 u64 = 4 class Struct: name: str members: list["Member"] class List: cnt: str typ: Atom | Struct def __init__(self, /, *, cnt: str, typ: Atom | Struct) -> None: self.cnt = cnt self.typ = typ class Member: name: str typ: Atom | Struct | List def __init__(self, /, *, name: str, typ: Atom | Struct | List) -> None: self.name = name self.typ = typ def parse_members( env: dict[str, Atom | Struct], existing: list[Member], specs: str ) -> list[Member]: ret = existing for spec in specs.split(): m = re.fullmatch(r"(.+)\[([^*]+)(?:\*([^*]+))?\]", spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") if m.group(2) not in env: raise NameError(f"Unknown type {repr(m.group(2))}") name = m.group(1) typ = env[m.group(2)] if any(x.name == name for x in ret): raise ValueError(f"duplicate member name {repr(name)}") if cnt := m.group(3): if len(ret) == 0 or ret[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") if not isinstance(ret[-1].typ, Atom): raise ValueError(f"list count must be an integer type: {repr(cnt)}") ret += [Member(name=name, typ=List(cnt=cnt, typ=typ))] else: ret += [Member(name=name, typ=typ)] return ret class Message: id: int name: str members: list[Member] def parse_file(filename: str) -> tuple[list[Struct], list[Message]]: msgs: list[Message] = [] env: dict[str, Atom | Struct] = { "1": Atom.u8, "2": Atom.u16, "4": Atom.u32, "8": Atom.u64, } with open(filename, "r") as fh: prev: Struct | Message | None = None for line in fh: line = line.split("#", 1)[0].strip() if not line: continue if m := re.fullmatch(r'([0-9]+)\s*=\s*(\S+)\s*"([^"]*)"', line): msg = Message() msg.id = int(m.group(1)) msg.name = m.group(2) msg.members = parse_members(env, [], m.group(3)) msgs += [msg] prev = msg elif m := re.fullmatch(r'(\S+)\s*=\s*"([^"]*)"', line): struct = Struct() struct.name = m.group(1) struct.members = parse_members(env, [], m.group(2)) env[struct.name] = struct prev = struct elif m := re.fullmatch(r'"([^"]*)"', line): if not prev: raise SyntaxError( "a continuation line must come after a struct line" ) prev.members = parse_members(env, prev.members, line.strip('"')) else: raise SyntaxError(f"invalid line {repr(line)}") structs = [x for x in env.values() if isinstance(x, Struct)] return structs, msgs def c_typename(typ: Atom | Struct | List | Message) -> str: match typ: case Atom.u8: return "uint8_t" case Atom.u16: return "uint16_t" case Atom.u32: return "uint32_t" case Atom.u64: return "uint64_t" case Struct(): return "struct v9fs_" + typ.name case Message(): return "struct v9fs_msg_" + typ.name case List(): return c_typename(typ.typ) + "*" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def gen_h(structs: list[Struct], msgs: list[Message]) -> str: ret = "" ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" ret += "\n" ret += "#ifndef _NET9P_DEFS_H_\n" ret += "#define _NET9P_DEFS_H_\n" ret += "\n" for struct in structs: ret += c_typename(struct) + " {\n" typewidth = max(len(c_typename(member.typ)) for member in struct.members) for member in struct.members: ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n" ret += "};\n" ret += "\n" ret += "enum v9fs_msg_type {\n" namewidth = max(len(msg.name) for msg in msgs) for msg in msgs: ret += f"\tV9FS_TYP_{msg.name.ljust(namewidth)} = {msg.id},\n" ret += "};\n" for msg in msgs: if not msg.members: ret += c_typename(msg) + " {};\n" else: ret += c_typename(msg) + " {\n" typewidth = max(len(c_typename(member.typ)) for member in msg.members) for member in msg.members: ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n" ret += "};\n" ret += "\n" ret += "#endif /* _NET9P_DEFS_H_ */\n" return ret c_atom_funcs = """ static inline uint16_t unmarshal_u16le(uint8_t *bytes) { return (((uint16_t)(bytes[0])) << 0) | (((uint16_t)(bytes[1])) << 8) ; } static inline uint32_t unmarshal_u32le(uint8_t *bytes) { return (((uint16_t)(bytes[0])) << 0) | (((uint16_t)(bytes[1])) << 8) | (((uint16_t)(bytes[2])) << 16) | (((uint16_t)(bytes[3])) << 24) ; } static inline uint64_t unmarshal_u64le(uint8_t *bytes) { return (((uint16_t)(bytes[0])) << 0) | (((uint16_t)(bytes[1])) << 8) | (((uint16_t)(bytes[2])) << 16) | (((uint16_t)(bytes[3])) << 24) | (((uint16_t)(bytes[4])) << 32) | (((uint16_t)(bytes[5])) << 40) | (((uint16_t)(bytes[6])) << 48) | (((uint16_t)(bytes[7])) << 56) ; } static inline void marshal_u16le(val uint16_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); } static inline void marshal_u32le(val uint32_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); bytes[2] = (uint8_t)((val >> 16) & 0xFF); bytes[3] = (uint8_t)((val >> 24) & 0xFF); } static inline void marshal_u64le(val uint64_t, uint8_t *bytes) { bytes[0] = (uint8_t)((val >> 0) & 0xFF); bytes[1] = (uint8_t)((val >> 8) & 0xFF); bytes[2] = (uint8_t)((val >> 16) & 0xFF); bytes[3] = (uint8_t)((val >> 24) & 0xFF); bytes[4] = (uint8_t)((val >> 32) & 0xFF); bytes[5] = (uint8_t)((val >> 40) & 0xFF); bytes[6] = (uint8_t)((val >> 48) & 0xFF); bytes[7] = (uint8_t)((val >> 56) & 0xFF); } """ def static_net_size(typ: Atom | Struct | List | Message) -> int | None: match typ: case Atom.u8: return 1 case Atom.u16: return 2 case Atom.u32: return 4 case Atom.u64: return 8 case Struct() | Message(): size = 0 for member in typ.members: msize = static_net_size(member.typ) if msize is None: return None size += msize return size case List(): return None case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def gen_c_check_net_len(msg: Message) -> str: ret: str = "" if (sz := static_net_size(msg)) is not None: ret += f"\tif (net_len != {sz})\n" ret += "\t\treturn -EINVAL;\n" return ret ret += f"\tuint64_t net_offset = 0;\n" static_acc = 0 def _gen_c_check_net_len(prefix: str, struct: Struct | Message) -> str: nonlocal static_acc ret: str = "" prev_size: int = 0 for member in struct.members: if (sz := static_net_size(member.typ)) is not None: static_acc += sz prev_size = sz elif isinstance(member.typ, Struct): ret += _gen_c_check_net_len(prefix, member.typ) elif isinstance(member.typ, List): if static_acc: ret += f"{prefix}net_offset += {static_acc};\n" static_acc = 0 ret += f"{prefix}if ((uint64_t)net_len < net_offset)\n" ret += f"{prefix}\treturn -EINVAL;\n" if (sz := static_net_size(member.typ.typ)) is not None: ret += f"{prefix}net_offset += unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}])*{sz};\n" else: assert isinstance(member.typ.typ, Struct) ret += f"{prefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}]); i < cnt; i++) {{\n" ret += _gen_c_check_net_len(prefix + "\t", member.typ.typ) if static_acc: ret += f"{prefix}net_offset += {static_acc};\n" static_acc = 0 ret += f"{prefix}\n" else: raise ValueError(f"not a type: {member.typ.__class__.__name__}") return ret ret += _gen_c_check_net_len("\t", msg) if static_acc: ret += f"\tnet_offset += {static_acc};\n" ret += "\tif ((uint64_t)net_len != net_offset)\n" ret += "\t\treturn -EINVAL;\n" return ret def gen_c(structs: list[Struct], msgs: list[Message]) -> str: ret = "" ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n" ret += "\n" ret += "#include /* for size_t, uint{n}_t */\n" ret += "#include /* for malloc() */\n" ret += "\n" ret += c_atom_funcs ret += "\n" for msg in msgs: ret += f"int unmarshal_msg_{msg.name}(uint32_t net_len, uint8_t *net_bytes, {c_typename(msg)} **ret) {{\n" ret += gen_c_check_net_len(msg) ret += "\n" ret += "\tTODO;\n" ret += "}\n" ret += "\n" return ret if __name__ == "__main__": structs, msgs = parse_file("net9p_defs.txt") print(gen_h(structs, msgs)) print(gen_c(structs, msgs))