summaryrefslogtreecommitdiff
path: root/net9p_defs.gen
diff options
context:
space:
mode:
Diffstat (limited to 'net9p_defs.gen')
-rwxr-xr-xnet9p_defs.gen309
1 files changed, 309 insertions, 0 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen
new file mode 100755
index 0000000..0e75e42
--- /dev/null
+++ b/net9p_defs.gen
@@ -0,0 +1,309 @@
+#!/usr/bin/env python
+
+import enum
+import re
+
+
+class Atom(enum.Enum):
+ u8 = 1
+ u16 = 2
+ u32 = 3
+ u64 = 4
+
+
+class Struct:
+ name: str
+ members: list["Member"]
+
+
+class List:
+ cnt: str
+ typ: Atom | Struct
+
+ def __init__(self, /, *, cnt: str, typ: Atom | Struct) -> None:
+ self.cnt = cnt
+ self.typ = typ
+
+
+class Member:
+ name: str
+ typ: Atom | Struct | List
+
+ def __init__(self, /, *, name: str, typ: Atom | Struct | List) -> None:
+ self.name = name
+ self.typ = typ
+
+
+def parse_members(
+ env: dict[str, Atom | Struct], existing: list[Member], specs: str
+) -> list[Member]:
+ ret = existing
+ for spec in specs.split():
+ m = re.fullmatch(r"(.+)\[([^*]+)(?:\*([^*]+))?\]", spec)
+ if not m:
+ raise SyntaxError(f"invalid member spec {repr(spec)}")
+ if m.group(2) not in env:
+ raise NameError(f"Unknown type {repr(m.group(2))}")
+ name = m.group(1)
+ typ = env[m.group(2)]
+ if any(x.name == name for x in ret):
+ raise ValueError(f"duplicate member name {repr(name)}")
+ if cnt := m.group(3):
+ if len(ret) == 0 or ret[-1].name != cnt:
+ raise ValueError(f"list count must be previous item: {repr(cnt)}")
+ if not isinstance(ret[-1].typ, Atom):
+ raise ValueError(f"list count must be an integer type: {repr(cnt)}")
+ ret += [Member(name=name, typ=List(cnt=cnt, typ=typ))]
+ else:
+ ret += [Member(name=name, typ=typ)]
+ return ret
+
+
+class Message:
+ id: int
+ name: str
+ members: list[Member]
+
+
+def parse_file(filename: str) -> tuple[list[Struct], list[Message]]:
+ msgs: list[Message] = []
+ env: dict[str, Atom | Struct] = {
+ "1": Atom.u8,
+ "2": Atom.u16,
+ "4": Atom.u32,
+ "8": Atom.u64,
+ }
+ with open(filename, "r") as fh:
+ prev: Struct | Message | None = None
+ for line in fh:
+ line = line.split("#", 1)[0].strip()
+ if not line:
+ continue
+ if m := re.fullmatch(r'([0-9]+)\s*=\s*(\S+)\s*"([^"]*)"', line):
+ msg = Message()
+ msg.id = int(m.group(1))
+ msg.name = m.group(2)
+ msg.members = parse_members(env, [], m.group(3))
+ msgs += [msg]
+ prev = msg
+ elif m := re.fullmatch(r'(\S+)\s*=\s*"([^"]*)"', line):
+ struct = Struct()
+ struct.name = m.group(1)
+ struct.members = parse_members(env, [], m.group(2))
+ env[struct.name] = struct
+ prev = struct
+ elif m := re.fullmatch(r'"([^"]*)"', line):
+ if not prev:
+ raise SyntaxError(
+ "a continuation line must come after a struct line"
+ )
+ prev.members = parse_members(env, prev.members, line.strip('"'))
+ else:
+ raise SyntaxError(f"invalid line {repr(line)}")
+ structs = [x for x in env.values() if isinstance(x, Struct)]
+ return structs, msgs
+
+
+def c_typename(typ: Atom | Struct | List | Message) -> str:
+ match typ:
+ case Atom.u8:
+ return "uint8_t"
+ case Atom.u16:
+ return "uint16_t"
+ case Atom.u32:
+ return "uint32_t"
+ case Atom.u64:
+ return "uint64_t"
+ case Struct():
+ return "struct v9fs_" + typ.name
+ case Message():
+ return "struct v9fs_msg_" + typ.name
+ case List():
+ return c_typename(typ.typ) + "*"
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def gen_h(structs: list[Struct], msgs: list[Message]) -> str:
+ ret = ""
+ ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n"
+ ret += "\n"
+ ret += "#ifndef _NET9P_DEFS_H_\n"
+ ret += "#define _NET9P_DEFS_H_\n"
+ ret += "\n"
+
+ for struct in structs:
+ ret += c_typename(struct) + " {\n"
+ typewidth = max(len(c_typename(member.typ)) for member in struct.members)
+ for member in struct.members:
+ ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n"
+ ret += "};\n"
+ ret += "\n"
+
+ ret += "enum v9fs_msg_type {\n"
+ namewidth = max(len(msg.name) for msg in msgs)
+ for msg in msgs:
+ ret += f"\tV9FS_TYP_{msg.name.ljust(namewidth)} = {msg.id},\n"
+ ret += "};\n"
+
+ for msg in msgs:
+ if not msg.members:
+ ret += c_typename(msg) + " {};\n"
+ else:
+ ret += c_typename(msg) + " {\n"
+ typewidth = max(len(c_typename(member.typ)) for member in msg.members)
+ for member in msg.members:
+ ret += f"\t{c_typename(member.typ).ljust(typewidth)} {member.name};\n"
+ ret += "};\n"
+ ret += "\n"
+
+ ret += "#endif /* _NET9P_DEFS_H_ */\n"
+ return ret
+
+
+c_atom_funcs = """
+static inline uint16_t unmarshal_u16le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ ;
+}
+static inline uint32_t unmarshal_u32le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ | (((uint16_t)(bytes[2])) << 16)
+ | (((uint16_t)(bytes[3])) << 24)
+ ;
+}
+static inline uint64_t unmarshal_u64le(uint8_t *bytes) {
+ return (((uint16_t)(bytes[0])) << 0)
+ | (((uint16_t)(bytes[1])) << 8)
+ | (((uint16_t)(bytes[2])) << 16)
+ | (((uint16_t)(bytes[3])) << 24)
+ | (((uint16_t)(bytes[4])) << 32)
+ | (((uint16_t)(bytes[5])) << 40)
+ | (((uint16_t)(bytes[6])) << 48)
+ | (((uint16_t)(bytes[7])) << 56)
+ ;
+}
+
+static inline void marshal_u16le(val uint16_t, uint8_t *bytes) {
+ bytes[0] = (uint8_t)((val >> 0) & 0xFF);
+ bytes[1] = (uint8_t)((val >> 8) & 0xFF);
+}
+static inline void marshal_u32le(val uint32_t, uint8_t *bytes) {
+ bytes[0] = (uint8_t)((val >> 0) & 0xFF);
+ bytes[1] = (uint8_t)((val >> 8) & 0xFF);
+ bytes[2] = (uint8_t)((val >> 16) & 0xFF);
+ bytes[3] = (uint8_t)((val >> 24) & 0xFF);
+}
+static inline void marshal_u64le(val uint64_t, uint8_t *bytes) {
+ bytes[0] = (uint8_t)((val >> 0) & 0xFF);
+ bytes[1] = (uint8_t)((val >> 8) & 0xFF);
+ bytes[2] = (uint8_t)((val >> 16) & 0xFF);
+ bytes[3] = (uint8_t)((val >> 24) & 0xFF);
+ bytes[4] = (uint8_t)((val >> 32) & 0xFF);
+ bytes[5] = (uint8_t)((val >> 40) & 0xFF);
+ bytes[6] = (uint8_t)((val >> 48) & 0xFF);
+ bytes[7] = (uint8_t)((val >> 56) & 0xFF);
+}
+"""
+
+
+def static_net_size(typ: Atom | Struct | List | Message) -> int | None:
+ match typ:
+ case Atom.u8:
+ return 1
+ case Atom.u16:
+ return 2
+ case Atom.u32:
+ return 4
+ case Atom.u64:
+ return 8
+ case Struct() | Message():
+ size = 0
+ for member in typ.members:
+ msize = static_net_size(member.typ)
+ if msize is None:
+ return None
+ size += msize
+ return size
+ case List():
+ return None
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def gen_c_check_net_len(msg: Message) -> str:
+ ret: str = ""
+ if (sz := static_net_size(msg)) is not None:
+ ret += f"\tif (net_len != {sz})\n"
+ ret += "\t\treturn -EINVAL;\n"
+ return ret
+
+ ret += f"\tuint64_t net_offset = 0;\n"
+ static_acc = 0
+ def _gen_c_check_net_len(prefix: str, struct: Struct | Message) -> str:
+ nonlocal static_acc
+ ret: str = ""
+
+ prev_size: int = 0
+ for member in struct.members:
+ if (sz := static_net_size(member.typ)) is not None:
+ static_acc += sz
+ prev_size = sz
+ elif isinstance(member.typ, Struct):
+ ret += _gen_c_check_net_len(prefix, member.typ)
+ elif isinstance(member.typ, List):
+ if static_acc:
+ ret += f"{prefix}net_offset += {static_acc};\n"
+ static_acc = 0
+ ret += f"{prefix}if ((uint64_t)net_len < net_offset)\n"
+ ret += f"{prefix}\treturn -EINVAL;\n"
+ if (sz := static_net_size(member.typ.typ)) is not None:
+ ret += f"{prefix}net_offset += unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}])*{sz};\n"
+ else:
+ assert isinstance(member.typ.typ, Struct)
+ ret += f"{prefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}]); i < cnt; i++) {{\n"
+ ret += _gen_c_check_net_len(prefix + "\t", member.typ.typ)
+ if static_acc:
+ ret += f"{prefix}net_offset += {static_acc};\n"
+ static_acc = 0
+ ret += f"{prefix}\n"
+ else:
+ raise ValueError(f"not a type: {member.typ.__class__.__name__}")
+
+ return ret
+ ret += _gen_c_check_net_len("\t", msg)
+ if static_acc:
+ ret += f"\tnet_offset += {static_acc};\n"
+ ret += "\tif ((uint64_t)net_len != net_offset)\n"
+ ret += "\t\treturn -EINVAL;\n"
+ return ret
+
+
+
+
+def gen_c(structs: list[Struct], msgs: list[Message]) -> str:
+ ret = ""
+ ret += "/* Generated by ./net9p_defs.gen. DO NOT EDIT! */\n"
+ ret += "\n"
+ ret += "#include <stdint.h> /* for size_t, uint{n}_t */\n"
+ ret += "#include <stdlib.h> /* for malloc() */\n"
+ ret += "\n"
+ ret += c_atom_funcs
+ ret += "\n"
+
+ for msg in msgs:
+ ret += f"int unmarshal_msg_{msg.name}(uint32_t net_len, uint8_t *net_bytes, {c_typename(msg)} **ret) {{\n"
+ ret += gen_c_check_net_len(msg)
+ ret += "\n"
+ ret += "\tTODO;\n"
+ ret += "}\n"
+ ret += "\n"
+ return ret
+
+
+if __name__ == "__main__":
+ structs, msgs = parse_file("net9p_defs.txt")
+ print(gen_h(structs, msgs))
+ print(gen_c(structs, msgs))