#!/usr/bin/env python # lib9p/types.gen - Generate C marshalers/unmarshalers for .txt files # defining 9P protocol variants. # # Copyright (C) 2024 Luke T. Shumaker # SPDX-Licence-Identifier: AGPL-3.0-or-later import enum import os.path import re from typing import Callable # This strives to be "general-purpose" in that it just acts on the # *.txt inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # Parse *.txt ################################################################## class Atom(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def name(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value # `msgid/structname = "member1 member2..."` # `structname = "member1 member2..."` # `structname += "member1 member2..."` class Struct: msgid: int | None = None msgver: set[str] name: str members: list["Member"] def __init__(self) -> None: self.msgver = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: msize = member.static_size if msize is None: return None size += msize return size # `cnt*(name[typ])` # the `cnt*(...)` wrapper is optional class Member: cnt: str | None = None name: str typ: Atom | Struct ver: set[str] @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" re_memberspec = ( f"(?:(?P{re_membername})\\*\\()?(?P{re_membername})\\[(?P.*)\\]\\)?" ) def parse_members( ver: str, env: dict[str, Atom | Struct], existing: list[Member], specs: str ) -> list[Member]: ret = existing for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") member = Member() member.ver = {ver} member.name = m.group("name") if any(x.name == member.name for x in ret): raise ValueError(f"duplicate member name {repr(member.name)}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group(2))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if len(ret) == 0 or ret[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") if not isinstance(ret[-1].typ, Atom): raise ValueError(f"list count must be an integer type: {repr(cnt)}") member.cnt = cnt ret += [member] return ret re_version = r'version\s+"(?P[^"]+)"' re_import = r"from\s+(?P\S+)\s+import\s+(?P\S+(?:\s*,\s*\S+)*)\s*" re_structspec = ( r'(?:(?P[0-9]+)/)?(?P\S+)\s*(?P\+?=)\s*"(?P[^"]*)"' ) re_structspec_cont = r'"(?P[^"]*)"' def parse_file( filename: str, get_include: Callable[[str], tuple[str, list[Struct]]] ) -> tuple[str, list[Struct]]: version: str | None = None env: dict[str, Atom | Struct] = { "1": Atom.u8, "2": Atom.u16, "4": Atom.u32, "8": Atom.u64, } with open(filename, "r") as fh: prev: Struct | None = None for line in fh: line = line.split("#", 1)[0].strip() if not line: continue if m := re.fullmatch(re_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") elif m := re.fullmatch(re_import, line): if not version: raise SyntaxError("must have exactly 1 version line") other_version, other_structs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() for struct in other_structs: if struct.name == symname or symname == "*": if struct.msgid: struct.msgver.add(version) for member in struct.members: if other_version in member.ver: member.ver.add(version) env[struct.name] = struct elif m := re.fullmatch(re_structspec, line): if not version: raise SyntaxError("must have exactly 1 version line") if m.group("op") == "+=" and m.group("msgid"): raise SyntaxError("cannot += to a message that is not yet defined") match m.group("op"): case "=": struct = Struct() if m.group("msgid"): struct.msgid = int(m.group("msgid")) struct.msgver.add(version) struct.name = m.group("name") struct.members = parse_members( version, env, [], m.group("members") ) env[struct.name] = struct prev = struct case "+=": if m.group("name") not in env: raise NameError(f"Unknown type {repr(m.group('name'))}") _struct = env[m.group("name")] if not isinstance(_struct, Struct): raise NameError( f"Type {repr(m.group('name'))} is not a struct" ) struct = _struct struct.members = parse_members( version, env, struct.members, m.group("members") ) prev = struct elif m := re.fullmatch(re_structspec_cont, line): if not prev: raise SyntaxError("continuation line must come after a struct line") assert version prev.members = parse_members( version, env, prev.members, m.group("members") ) else: raise SyntaxError(f"invalid line {repr(line)}") if not version: raise SyntaxError("must have exactly 1 version line") structs = [x for x in env.values() if isinstance(x, Struct)] return version, structs # Generate C ################################################################### def c_typename(idprefix: str, typ: Atom | Struct) -> str: match typ: case Atom(): return f"uint{typ.value*8}_t" case Struct(): if typ.msgid is not None: return f"struct {idprefix}msg_{typ.name}" return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def c_verenum(idprefix: str, ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" def c_vercomment(versions: set[str]) -> str | None: if "9P2000" in versions: return None return "/* " + (", ".join(sorted(versions))) + " */" def c_vercond(idprefix: str, versions: set[str]) -> str: if len(versions) == 1: return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" return ( "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" ) def gen_h(idprefix: str, versions: set[str], structs: list[Struct]) -> str: guard = "_LIB9P__TYPES_H_" ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef {guard} #define {guard} #include /* for uint{{n}}_t types */ """ ret += f""" /* versions *******************************************************************/ enum {idprefix}version {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: ret += f"\t{c_verenum(idprefix, ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += f"\t{c_verenum(idprefix, 'NUM')},\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ /* non-message structs ********************************************************/ """ for struct in structs: if struct.msgid is not None: continue all_the_same = len(struct.members) == 0 or all( m.ver == struct.members[0].ver for m in struct.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) if not all_the_same: namewidth = max(len(m.name) for m in struct.members) ret += "\n" ret += c_typename(idprefix, struct) + " {\n" for member in struct.members: ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" ret += """ /* messages *******************************************************************/ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" namewidth = max(len(msg.name) for msg in structs if msg.msgid is not None) for msg in structs: if msg.msgid is None: continue ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," if comment := c_vercomment(msg.msgver): ret += " " + comment ret += "\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" for msg in structs: if msg.msgid is None: continue ret += "\n" if comment := c_vercomment(msg.msgver): ret += comment + "\n" ret += c_typename(idprefix, msg) + " {" if not msg.members: ret += "};\n" continue ret += "\n" all_the_same = len(msg.members) == 0 or all( m.ver == msg.members[0].ver for m in msg.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) if not all_the_same: namewidth = max(len(m.name) for m in msg.members) for member in msg.members: ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" ret += "\n" ret += f"#endif /* {guard} */\n" return ret def gen_c(idprefix: str, versions: set[str], structs: list[Struct]) -> str: ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include "internal.h" """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"UNUSED({arg})" # strings ################################################################## ret += f""" /* strings ********************************************************************/ static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n' ret += "};\n" ret += f""" const char *{idprefix}version_str(enum {idprefix}version ver) {{ assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')}); return version_strs[ver]; }} static const char *msg_type_strs[0x100] = {{ """ id2name: dict[int, str] = {} for msg in structs: if msg.msgid is not None: id2name[msg.msgid] = msg.name for n in range(0, 0x100): ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) ret += "};\n" ret += f""" const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ assert(0 <= typ && typ <= 0xFF); return msg_type_strs[typ]; }} """ # checksize_* ############################################################## ret += """ /* checksize_* (internals of unmarshal_size()) ********************************/ static inline bool _checksize_net(struct _checksize_ctx *ctx, uint32_t n) { if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) /* If needed-net-size overflowed uint32_t, then * there's no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); if (ctx->net_offset > ctx->net_size) return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static inline bool _checksize_host(struct _checksize_ctx *ctx, size_t n) { if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) /* If needed-host-size overflowed size_t, then there's * no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static inline bool _checksize_list(struct _checksize_ctx *ctx, size_t cnt, _checksize_fn_t item_fn, size_t item_host_size) { for (size_t i = 0; i < cnt; i++) if (_checksize_host(ctx, item_host_size) || item_fn(ctx)) return true; return false; } #define checksize_1(ctx) _checksize_net(ctx, 1) #define checksize_2(ctx) _checksize_net(ctx, 2) #define checksize_4(ctx) _checksize_net(ctx, 4) #define checksize_8(ctx) _checksize_net(ctx, 8) """ for struct in structs: inline = " inline" if struct.msgid is None else "" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} bool checksize_{struct.name}(struct _checksize_ctx *{argfn('ctx')}) {{" if len(struct.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue if struct.name == "d": # SPECIAL # Optimize... maybe the compiler could figure out to do # this, but let's make it obvious. ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (checksize_4(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" ret += "\treturn _checksize_net(ctx, len) || _checksize_host(ctx, len);\n" ret += "}\n" continue if struct.name == "s": # SPECIAL # Add an extra nul-byte on the host, and validate UTF-8 # (also, similar optimization to "d"). ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (checksize_2(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" ret += "\tif (_checksize_net(ctx, len) || _checksize_host(ctx, ((size_t)len)+1))\n" ret += "\t\treturn true;\n" ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' ret += "\treturn false;\n" ret += "}\n" continue prefix0 = "\treturn " prefix1 = "\t || " struct_versions = struct.members[0].ver prefix = prefix0 prev_size: int | None = None for member in struct.members: ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt is not None: assert prev_size ret += f"_checksize_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), checksize_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: ret += f"checksize_{member.typ.name}(ctx)" if member.ver != struct_versions: ret += " )" prefix = prefix1 prev_size = member.static_size ret += ";\n}\n" # unmarshal_* ############################################################## ret += """ /* unmarshal_* ****************************************************************/ static inline void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 1; } static inline void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; } static inline void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; } static inline void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; } """ for struct in structs: inline = " inline" if struct.msgid is None else "" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} void unmarshal_{struct.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *out) {{\n" ret += "\tmemset(out, 0, sizeof(*out));\n" if struct.members: struct_versions = struct.members[0].ver for member in struct.members: ret += "\t" prefix = "\t" if member.ver != struct_versions: ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " prefix = "\t\t" if member.cnt: if member.ver != struct_versions: ret += "{\n" ret += f"{prefix}out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.ver != struct_versions: ret += "\t}\n" else: ret += f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ret += "}\n" # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ static inline bool _marshal_too_large(struct _marshal_ctx *ctx) { lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), ctx->ctx->max_msg_size); return true; } static inline bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); ctx->net_bytes[ctx->net_offset] = *val; ctx->net_offset += 1; return false; } static inline bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; return false; } static inline bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) return true; encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; return false; } static inline bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) return true; encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; return false; } """ for struct in structs: inline = " inline" if struct.msgid is None else "" argfn = used if struct.members else unused ret += "\n" ret += f"static{inline} bool marshal_{struct.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('val')}) {{" if len(struct.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " struct_versions = struct.members[0].ver prefix = prefix0 for member in struct.members: ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt: ret += "({" ret += f"\n{prefix2}\tbool err = false;" ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" if member.ver != struct_versions: ret += " )" prefix = prefix1 ret += ";\n}\n" # vtables ################################################################## ret += f""" /* vtables ********************************************************************/ #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ .unmarshal_basesize = sizeof(struct {idprefix}msg_##typ), \\ .unmarshal_extrasize = checksize_##typ, \\ .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ .marshal = (_marshal_fn_t)marshal_##typ, \\ }} struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{ """ ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n" for msg in structs: if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" for ver in sorted(versions): ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" for msg in structs: if ver not in msg.msgver: continue ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" ret += "};\n" ############################################################################ return ret ################################################################################ class Parser: cache: dict[str, tuple[str, list[Struct]]] = {} def parse_file(self, filename: str) -> tuple[str, list[Struct]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[Struct]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[Struct]]: ret_versions: set[str] = set() ret_structs: dict[str, Struct] = {} for version, structs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for struct in structs: if struct.name in ret_structs: if struct != ret_structs[struct.name]: raise ValueError(f"duplicate struct name {repr(struct.name)}") else: ret_structs[struct.name] = struct return ret_versions, list(ret_structs.values()) if __name__ == "__main__": import sys if len(sys.argv) < 2: raise ValueError("requires at least 1 .txt filename") parser = Parser() for txtname in sys.argv[1:]: parser.parse_file(txtname) versions, structs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/_types.h"), "w") as fh: fh.write(gen_h("lib9p_", versions, structs)) with open(os.path.join(outdir, "types.c"), "w") as fh: fh.write(gen_c("lib9p_", versions, structs))