#!/usr/bin/env python # lib9p/types.gen - Generate C marshalers/unmarshalers for .txt files # defining 9P protocol variants. # # Copyright (C) 2024 Luke T. Shumaker # SPDX-Licence-Identifier: AGPL-3.0-or-later import enum import os.path import re from typing import Callable, Sequence # This strives to be "general-purpose" in that it just acts on the # *.txt inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # Parse *.txt ################################################################## class Atom(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def name(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value class Bitfield: name: str bits: list[str] aliases: dict[str, str] @property def static_size(self) -> int | None: return int((len(self.bits) + 7) / 8) # `msgid/structname = "member1 member2..."` # `structname = "member1 member2..."` # `structname += "member1 member2..."` class Struct: msgid: int | None = None msgver: set[str] name: str members: list["Member"] def __init__(self) -> None: self.msgver = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: msize = member.static_size if msize is None: return None size += msize return size # `cnt*(name[typ])` # the `cnt*(...)` wrapper is optional class Member: cnt: str | None = None name: str typ: Atom | Bitfield | Struct ver: set[str] @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" re_memberspec = ( f"(?:(?P{re_membername})\\*\\()?(?P{re_membername})\\[(?P.*)\\]\\)?" ) def parse_members( ver: str, env: dict[str, Atom | Bitfield | Struct], existing: list[Member], specs: str, ) -> list[Member]: ret = existing for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") member = Member() member.ver = {ver} member.name = m.group("name") if any(x.name == member.name for x in ret): raise ValueError(f"duplicate member name {repr(member.name)}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group(2))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if len(ret) == 0 or ret[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") if not isinstance(ret[-1].typ, Atom): raise ValueError(f"list count must be an integer type: {repr(cnt)}") member.cnt = cnt ret += [member] return ret re_version = r'version\s+"(?P[^"]+)"' re_import = r"from\s+(?P\S+)\s+import\s+(?P\S+(?:\s*,\s*\S+)*)\s*" re_structspec = ( r'(?:(?P[0-9]+)/)?(?P\S+)\s*(?P\+?=)\s*"(?P[^"]*)"' ) re_structspec_cont = r'\s+"(?P[^"]*)"' re_bitfieldspec = r"bitfield\s+(?P\S+)\s+(?P[0-9]+)" re_bitfieldspec_bit = r"(?:\s+|(?P\S+)\s*\+=\s*)(?P[0-9]+)/(?P\S+)" re_bitfieldspec_alias = r"(?:\s+|(?P\S+)\s*\+=\s*)(?P\S+)=(?P.*)" def parse_file( filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]] ) -> tuple[str, list[Bitfield | Struct]]: version: str | None = None env: dict[str, Atom | Bitfield | Struct] = { "1": Atom.u8, "2": Atom.u16, "4": Atom.u32, "8": Atom.u64, } with open(filename, "r") as fh: prev: Struct | Bitfield | None = None for line in fh: line = line.split("#", 1)[0].rstrip() if not line: continue if m := re.fullmatch(re_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") elif m := re.fullmatch(re_import, line): if not version: raise SyntaxError("must have exactly 1 version line") other_version, other_typs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() for typ in other_typs: if typ.name == symname or symname == "*": if isinstance(typ, Struct): if typ.msgid: typ.msgver.add(version) for member in typ.members: if other_version in member.ver: member.ver.add(version) env[typ.name] = typ elif m := re.fullmatch(re_structspec, line): if not version: raise SyntaxError("must have exactly 1 version line") if m.group("op") == "+=" and m.group("msgid"): raise SyntaxError("cannot += to a message that is not yet defined") match m.group("op"): case "=": struct = Struct() if m.group("msgid"): struct.msgid = int(m.group("msgid")) struct.msgver.add(version) struct.name = m.group("name") struct.members = parse_members( version, env, [], m.group("members") ) env[struct.name] = struct prev = struct case "+=": if m.group("name") not in env: raise NameError(f"Unknown type {repr(m.group('name'))}") _struct = env[m.group("name")] if not isinstance(_struct, Struct): raise NameError( f"Type {repr(_struct.name)} is not a struct" ) struct = _struct struct.members = parse_members( version, env, struct.members, m.group("members") ) prev = struct elif m := re.fullmatch(re_structspec_cont, line): if not isinstance(prev, Struct): raise SyntaxError( "struct-continuation line must come after a struct line" ) assert version prev.members = parse_members( version, env, prev.members, m.group("members") ) elif m := re.fullmatch(re_bitfieldspec, line): bf = Bitfield() bf.name = m.group("name") bf.bits = int(m.group("size")) * [""] bf.aliases = {} env[bf.name] = bf prev = bf elif m := re.fullmatch(re_bitfieldspec_bit, line): if m.group("bitfield"): if m.group("bitfield") not in env: raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") _bf = env[m.group("bitfield")] if not isinstance(_bf, Bitfield): raise NameError(f"Type {repr(_bf.name)} is not a bitfield") bf = _bf else: if not isinstance(prev, Bitfield): raise SyntaxError( "bitfield-continuation line must come after a bitfield line" ) bf = prev bit = int(m.group("bit")) name = m.group("name") if bit < 0 or bit >= len(bf.bits): raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") if bf.bits[bit]: raise ValueError(f"{bf.name}: bit {bit} already assigned") if name in bf.aliases: raise ValueError(f"{bf.name}: name {name} already assigned") bf.bits[bit] = name bf.aliases[name] = "" elif m := re.fullmatch(re_bitfieldspec_alias, line): if m.group("bitfield"): if m.group("bitfield") not in env: raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") _bf = env[m.group("bitfield")] if not isinstance(_bf, Bitfield): raise NameError(f"Type {repr(_bf.name)} is not a bitfield") bf = _bf else: if not isinstance(prev, Bitfield): raise SyntaxError( "bitfield-continuation line must come after a bitfield line" ) bf = prev name = m.group("name") val = m.group("val") if name in bf.aliases: raise ValueError(f"{bf.name}: name {name} already assigned") bf.aliases[name] = val else: raise SyntaxError(f"invalid line {repr(line)}") if not version: raise SyntaxError("must have exactly 1 version line") typs = [x for x in env.values() if not isinstance(x, Atom)] return version, typs # Generate C ################################################################### def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str: match typ: case Atom(): return f"uint{typ.value*8}_t" case Bitfield(): return f"{idprefix}{typ.name}_t" case Struct(): if typ.msgid is not None: return f"struct {idprefix}msg_{typ.name}" return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def c_verenum(idprefix: str, ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" def c_vercomment(versions: set[str]) -> str | None: if "9P2000" in versions: return None return "/* " + (", ".join(sorted(versions))) + " */" def c_vercond(idprefix: str, versions: set[str]) -> str: if len(versions) == 1: return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" return ( "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" ) def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list(typ for typ in typs if isinstance(typ, Struct)) def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None) def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list( typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None ) def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]: return list(typ for typ in typs if isinstance(typ, Bitfield)) def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: guard = "_LIB9P__TYPES_H_" ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef {guard} #define {guard} #include /* for uint{{n}}_t types */ """ ret += f""" /* versions *******************************************************************/ enum {idprefix}version {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: ret += f"\t{c_verenum(idprefix, ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += f"\t{c_verenum(idprefix, 'NUM')},\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ /* non-message types **********************************************************/ """ for bf in just_bitfields(typs): ret += "\n" ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n" vals = dict([ *reversed([((k or f"_UNUSED_{v}"), f"1<<{v}") for (v, k) in enumerate(bf.bits)]), *[(k, v) for (k, v) in bf.aliases.items() if v], ]) namewidth = max(len(name) for name in vals) for name, val in vals.items(): ret += f"#define {idprefix.upper()}{bf.name.upper()}_{name.ljust(namewidth)} (({c_typename(idprefix, bf)})({val}))\n" for struct in just_structs_nonmsg(typs): all_the_same = len(struct.members) == 0 or all( m.ver == struct.members[0].ver for m in struct.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) if not all_the_same: namewidth = max(len(m.name) for m in struct.members) ret += "\n" ret += c_typename(idprefix, struct) + " {\n" for member in struct.members: ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" ret += """ /* messages *******************************************************************/ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" namewidth = max(len(msg.name) for msg in just_structs_msg(typs)) for msg in just_structs_msg(typs): ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," if comment := c_vercomment(msg.msgver): ret += " " + comment ret += "\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" for msg in just_structs_msg(typs): ret += "\n" if comment := c_vercomment(msg.msgver): ret += comment + "\n" ret += c_typename(idprefix, msg) + " {" if not msg.members: ret += "};\n" continue ret += "\n" all_the_same = len(msg.members) == 0 or all( m.ver == msg.members[0].ver for m in msg.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) if not all_the_same: namewidth = max(len(m.name) for m in msg.members) for member in msg.members: ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" ret += "\n" ret += f"#endif /* {guard} */\n" return ret def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include "internal.h" """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"UNUSED({arg})" # strings ################################################################## ret += f""" /* strings ********************************************************************/ static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n' ret += "};\n" ret += f""" const char *{idprefix}version_str(enum {idprefix}version ver) {{ assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')}); return version_strs[ver]; }} static const char *msg_type_strs[0x100] = {{ """ id2name: dict[int, str] = {} for msg in just_structs_msg(typs): assert msg.msgid id2name[msg.msgid] = msg.name for n in range(0, 0x100): ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) ret += "};\n" ret += f""" const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ assert(0 <= typ && typ <= 0xFF); return msg_type_strs[typ]; }} """ # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) /* If needed-net-size overflowed uint32_t, then * there's no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); if (ctx->net_offset > ctx->net_size) return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) /* If needed-host-size overflowed size_t, then there's * no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { for (size_t i = 0; i < cnt; i++) if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) return true; return false; } #define validate_1(ctx) _validate_size_net(ctx, 1) #define validate_2(ctx) _validate_size_net(ctx, 2) #define validate_4(ctx) _validate_size_net(ctx, 4) #define validate_8(ctx) _validate_size_net(ctx, 8) """ for struct in just_structs_all(typs): inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} bool validate_{struct.name}(struct _validate_ctx *{argfn('ctx')}) {{" if len(struct.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue if struct.name == "d": # SPECIAL # Optimize... maybe the compiler could figure out to do # this, but let's make it obvious. ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_4(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" ret += "}\n" continue if struct.name == "s": # SPECIAL # Add an extra nul-byte on the host, and validate UTF-8 # (also, similar optimization to "d"). ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_2(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" ret += "\t\treturn true;\n" ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' ret += "\treturn false;\n" ret += "}\n" continue prefix0 = "\treturn " prefix1 = "\t || " struct_versions = struct.members[0].ver prefix = prefix0 prev_size: int | None = None for member in struct.members: ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt is not None: assert prev_size ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: ret += f"validate_{member.typ.name}(ctx)" if member.ver != struct_versions: ret += " )" prefix = prefix1 prev_size = member.static_size ret += ";\n}\n" # unmarshal_* ############################################################## ret += """ /* unmarshal_* ****************************************************************/ static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 1; } static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; } static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; } static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; } """ for struct in just_structs_all(typs): inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} void unmarshal_{struct.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *out) {{\n" ret += "\tmemset(out, 0, sizeof(*out));\n" if struct.members: struct_versions = struct.members[0].ver for member in struct.members: ret += "\t" prefix = "\t" if member.ver != struct_versions: ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " prefix = "\t\t" if member.cnt: if member.ver != struct_versions: ret += "{\n" ret += f"{prefix}out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.ver != struct_versions: ret += "\t}\n" else: ret += f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ret += "}\n" # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), ctx->ctx->max_msg_size); return true; } static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); ctx->net_bytes[ctx->net_offset] = *val; ctx->net_offset += 1; return false; } static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; return false; } static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) return true; encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; return false; } static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) return true; encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; return false; } """ for struct in just_structs_all(typs): inline = " ALWAYS_INLINE" if struct.msgid is None else " FLATTEN" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} bool marshal_{struct.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('val')}) {{" if len(struct.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " struct_versions = struct.members[0].ver prefix = prefix0 for member in struct.members: ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt: ret += "({" ret += f"\n{prefix2}\tbool err = false;" ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" if member.ver != struct_versions: ret += " )" prefix = prefix1 ret += ";\n}\n" # vtables ################################################################## ret += f""" /* vtables ********************************************************************/ #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ .basesize = sizeof(struct {idprefix}msg_##typ), \\ .validate = validate_##typ, \\ .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ .marshal = (_marshal_fn_t)marshal_##typ, \\ }} struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{ """ ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n" for msg in just_structs_msg(typs): if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" for ver in sorted(versions): ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" for msg in just_structs_msg(typs): if ver not in msg.msgver: continue ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" ret += "};\n" ############################################################################ return ret ################################################################################ class Parser: cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {} def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[Bitfield | Struct]]: ret_versions: set[str] = set() ret_typs: dict[str, Bitfield | Struct] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for typ in typs: if typ.name in ret_typs: if typ != ret_typs[typ.name]: raise ValueError(f"duplicate type name {repr(typ.name)}") else: ret_typs[typ.name] = typ return ret_versions, list(ret_typs.values()) if __name__ == "__main__": import sys if len(sys.argv) < 2: raise ValueError("requires at least 1 .txt filename") parser = Parser() for txtname in sys.argv[1:]: parser.parse_file(txtname) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/_types.h"), "w") as fh: fh.write(gen_h("lib9p_", versions, typs)) with open(os.path.join(outdir, "types.c"), "w") as fh: fh.write(gen_c("lib9p_", versions, typs))