#!/usr/bin/env python # lib9p/9p.gen - Generate C marshalers/unmarshalers for .txt files # defining 9P protocol variants. # # Copyright (C) 2024 Luke T. Shumaker # SPDX-Licence-Identifier: AGPL-3.0-or-later import enum import os.path import re from typing import Callable, Sequence # This strives to be "general-purpose" in that it just acts on the # *.txt inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # Parse *.txt ################################################################## class Atom(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def name(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value class BitfieldVal: name: str val: str ver: set[str] def __init__(self) -> None: self.ver = set() class Bitfield: name: str bits: list[str] names: dict[str, BitfieldVal] @property def static_size(self) -> int: return int((len(self.bits) + 7) / 8) def bitname_is_valid(self, bitname: str, ver: str | None = None) -> bool: assert bitname in self.bits if not bitname: return False if bitname.startswith("_"): return False if ver and (ver not in self.names[bitname].ver): return False return True # `msgid/structname = "member1 member2..."` # `structname = "member1 member2..."` # `structname += "member1 member2..."` class Struct: msgid: int | None = None msgver: set[str] name: str members: list["Member"] def __init__(self) -> None: self.msgver = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: msize = member.static_size if msize is None: return None size += msize return size # `cnt*(name[typ])` # the `cnt*(...)` wrapper is optional class Member: cnt: str | None = None name: str typ: Atom | Bitfield | Struct max: int | None = None ver: set[str] @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" re_memberspec = f"(?:(?P{re_membername})\\*\\()?(?P{re_membername})\\[(?P[^,]*)(?:,max=(?P[0-9]+))?\\]\\)?" def parse_members( ver: str, env: dict[str, Atom | Bitfield | Struct], existing: list[Member], specs: str, ) -> list[Member]: ret = existing for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") member = Member() member.ver = {ver} member.name = m.group("name") if any(x.name == member.name for x in ret): raise ValueError(f"duplicate member name {repr(member.name)}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group(2))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if len(ret) == 0 or ret[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") if not isinstance(ret[-1].typ, Atom): raise ValueError(f"list count must be an integer type: {repr(cnt)}") member.cnt = cnt if maxstr := m.group("max"): if (not isinstance(member.typ, Atom)) or member.cnt: raise ValueError( f"',max=' may only be specified on a non-repeated atom" ) member.max = int(maxstr) ret += [member] return ret re_version = r'version\s+"(?P[^"]+)"' re_import = r"from\s+(?P\S+)\s+import\s+(?P\S+(?:\s*,\s*\S+)*)\s*" re_structspec = ( r'(?:(?P[0-9]+)/)?(?P\S+)\s*(?P\+?=)\s*"(?P[^"]*)"' ) re_structspec_cont = r'\s+"(?P[^"]*)"' re_bitfieldspec = r"bitfield\s+(?P\S+)\s+(?P[0-9]+)" re_bitfieldspec_bit = r"(?:\s+|(?P\S+)\s*\+=\s*)(?P[0-9]+)/(?P\S+)" re_bitfieldspec_alias = ( r"(?:\s+|(?P\S+)\s*\+=\s*)(?P\S+)\s*=\s*(?P.*)" ) def parse_file( filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]] ) -> tuple[str, list[Bitfield | Struct]]: version: str | None = None env: dict[str, Atom | Bitfield | Struct] = { "1": Atom.u8, "2": Atom.u16, "4": Atom.u32, "8": Atom.u64, } with open(filename, "r") as fh: prev: Struct | Bitfield | None = None for line in fh: line = line.split("#", 1)[0].rstrip() if not line: continue if m := re.fullmatch(re_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") elif m := re.fullmatch(re_import, line): if not version: raise SyntaxError("must have exactly 1 version line") other_version, other_typs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() for typ in other_typs: if typ.name == symname or symname == "*": match typ: case Bitfield(): for val in typ.names.values(): if other_version in val.ver: val.ver.add(version) case Struct(): if typ.msgid: typ.msgver.add(version) for member in typ.members: if other_version in member.ver: member.ver.add(version) env[typ.name] = typ elif m := re.fullmatch(re_structspec, line): if not version: raise SyntaxError("must have exactly 1 version line") if m.group("op") == "+=" and m.group("msgid"): raise SyntaxError("cannot += to a message that is not yet defined") match m.group("op"): case "=": struct = Struct() if m.group("msgid"): struct.msgid = int(m.group("msgid")) struct.msgver.add(version) struct.name = m.group("name") struct.members = parse_members( version, env, [], m.group("members") ) env[struct.name] = struct prev = struct case "+=": if m.group("name") not in env: raise NameError(f"Unknown type {repr(m.group('name'))}") _struct = env[m.group("name")] if not isinstance(_struct, Struct): raise NameError( f"Type {repr(_struct.name)} is not a struct" ) struct = _struct struct.members = parse_members( version, env, struct.members, m.group("members") ) prev = struct elif m := re.fullmatch(re_structspec_cont, line): if not isinstance(prev, Struct): raise SyntaxError( "struct-continuation line must come after a struct line" ) assert version prev.members = parse_members( version, env, prev.members, m.group("members") ) elif m := re.fullmatch(re_bitfieldspec, line): if not version: raise SyntaxError("must have exactly 1 version line") bf = Bitfield() bf.name = m.group("name") bf.bits = int(m.group("size")) * [""] bf.names = {} if len(bf.bits) not in [8, 16, 32, 64]: raise ValueError(f"Bitfield {repr(bf.name)} has an unusual size") env[bf.name] = bf prev = bf elif m := re.fullmatch(re_bitfieldspec_bit, line): if m.group("bitfield"): if m.group("bitfield") not in env: raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") _bf = env[m.group("bitfield")] if not isinstance(_bf, Bitfield): raise NameError(f"Type {repr(_bf.name)} is not a bitfield") bf = _bf prev = bf else: if not isinstance(prev, Bitfield): raise SyntaxError( "bitfield-continuation line must come after a bitfield line" ) bf = prev bit = int(m.group("bit")) name = m.group("name") if bit < 0 or bit >= len(bf.bits): raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") if bf.bits[bit]: raise ValueError(f"{bf.name}: bit {bit} already assigned") if name in bf.names: raise ValueError(f"{bf.name}: name {name} already assigned") bf.bits[bit] = name assert version val = BitfieldVal() val.name = name val.val = f"1<<{bit}" val.ver.add(version) bf.names[name] = val elif m := re.fullmatch(re_bitfieldspec_alias, line): if m.group("bitfield"): if m.group("bitfield") not in env: raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") _bf = env[m.group("bitfield")] if not isinstance(_bf, Bitfield): raise NameError(f"Type {repr(_bf.name)} is not a bitfield") bf = _bf prev = bf else: if not isinstance(prev, Bitfield): raise SyntaxError( "bitfield-continuation line must come after a bitfield line" ) bf = prev name = m.group("name") valstr = m.group("val") if name in bf.names: raise ValueError(f"{bf.name}: name {name} already assigned") assert version val = BitfieldVal() val.name = name val.val = valstr val.ver.add(version) bf.names[name] = val else: raise SyntaxError(f"invalid line {repr(line)}") if not version: raise SyntaxError("must have exactly 1 version line") typs = [x for x in env.values() if not isinstance(x, Atom)] return version, typs # Generate C ################################################################### def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str: match typ: case Atom(): return f"uint{typ.value*8}_t" case Bitfield(): return f"{idprefix}{typ.name}_t" case Struct(): if typ.msgid is not None: return f"struct {idprefix}msg_{typ.name}" return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def c_verenum(idprefix: str, ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" def c_vercomment(versions: set[str]) -> str | None: if "9P2000" in versions: return None return "/* " + (", ".join(sorted(versions))) + " */" def c_vercond(idprefix: str, versions: set[str]) -> str: if len(versions) == 1: return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" return ( "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" ) def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list(typ for typ in typs if isinstance(typ, Struct)) def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None) def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: return list( typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None ) def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]: return list(typ for typ in typs if isinstance(typ, Bitfield)) def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ # error Do not include directly; include instead #endif #include /* for uint{{n}}_t types */ """ ret += f""" /* versions *******************************************************************/ enum {idprefix}version {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: ret += f"\t{c_verenum(idprefix, ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += f"\t{c_verenum(idprefix, 'NUM')},\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ /* non-message types **********************************************************/ """ for bf in just_bitfields(typs): ret += "\n" ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n" names = [ *reversed([bf.bits[n] or f"_UNUSED_{n}" for n in range(0, len(bf.bits))]), *[k for k in bf.names if k not in bf.bits], ] namewidth = max(len(name) for name in names) for name in names: if name.startswith("_"): cname = f"_{idprefix.upper()}{bf.name.upper()}_{name[1:]}" else: cname = f"{idprefix.upper()}{bf.name.upper()}_{name}" if name in bf.names: val = bf.names[name].val else: assert name.startswith("_UNUSED_") val = f"1<<{name[len('_UNUSED_'):]}" ret += f"#define {cname}{' '*(namewidth-len(name))} (({c_typename(idprefix, bf)})({val}))" if (name in bf.names) and (comment := c_vercomment(bf.names[name].ver)): ret += " " + comment ret += "\n" for struct in just_structs_nonmsg(typs): all_the_same = len(struct.members) == 0 or all( m.ver == struct.members[0].ver for m in struct.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) if not all_the_same: namewidth = max(len(m.name) for m in struct.members) ret += "\n" ret += c_typename(idprefix, struct) + " {\n" for member in struct.members: if struct.name == "stat" and member.name == "stat_size": # SPECIAL continue ctype = c_typename(idprefix, member.typ) if (struct.name in ["d", "s"]) and member.cnt: # SPECIAL ctype = "char" ret += f"\t{ctype.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" ret += """ /* messages *******************************************************************/ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" namewidth = max(len(msg.name) for msg in just_structs_msg(typs)) for msg in just_structs_msg(typs): ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," if comment := c_vercomment(msg.msgver): ret += " " + comment ret += "\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" for msg in just_structs_msg(typs): ret += "\n" if comment := c_vercomment(msg.msgver): ret += comment + "\n" ret += c_typename(idprefix, msg) + " {" if not msg.members: ret += "};\n" continue ret += "\n" all_the_same = len(msg.members) == 0 or all( m.ver == msg.members[0].ver for m in msg.members ) typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) if not all_the_same: namewidth = max(len(m.name) for m in msg.members) for member in msg.members: ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" if (not all_the_same) and (comment := c_vercomment(member.ver)): ret += (" " * (namewidth - len(member.name))) + " " + comment ret += "\n" ret += "};\n" return ret def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include "internal.h" """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"UNUSED({arg})" # strings ################################################################## ret += f""" /* strings ********************************************************************/ static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n' ret += "};\n" ret += f""" const char *{idprefix}version_str(enum {idprefix}version ver) {{ assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')}); return version_strs[ver]; }} static const char *msg_type_strs[0x100] = {{ """ id2name: dict[int, str] = {} for msg in just_structs_msg(typs): assert msg.msgid id2name[msg.msgid] = msg.name for n in range(0, 0x100): ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) ret += "};\n" ret += f""" const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ assert(0 <= typ && typ <= 0xFF); return msg_type_strs[typ]; }} """ # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) /* If needed-net-size overflowed uint32_t, then * there's no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); if (ctx->net_offset > ctx->net_size) return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) /* If needed-host-size overflowed size_t, then there's * no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { for (size_t i = 0; i < cnt; i++) if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) return true; return false; } #define validate_1(ctx) _validate_size_net(ctx, 1) #define validate_2(ctx) _validate_size_net(ctx, 2) #define validate_4(ctx) _validate_size_net(ctx, 4) #define validate_8(ctx) _validate_size_net(ctx, 8) """ for typ in typs: inline = ( " FLATTEN" if (isinstance(typ, Struct) and typ.msgid is not None) else " ALWAYS_INLINE" ) argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{" if typ.name == "d": # SPECIAL # Optimize... maybe the compiler could figure out to do # this, but let's make it obvious. ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_4(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" ret += "}\n" continue if typ.name == "s": # SPECIAL # Add an extra nul-byte on the host, and validate UTF-8 # (also, similar optimization to "d"). ret += "\n" ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_2(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" ret += "\t\treturn true;\n" ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' ret += "\treturn false;\n" ret += "}\n" continue match typ: case Bitfield(): ret += "\n" all_the_same = all( val.ver == [*typ.names.values()][0].ver for val in typ.names.values() ) if ( all_the_same and (len(typ.bits) == typ.static_size * 8) and all(typ.bitname_is_valid(bitname) for bitname in typ.bits) ): ret += f"\treturn validate_{typ.static_size}(ctx));\n" else: ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" if all_the_same: ret += ( f"\tstatic const {c_typename(idprefix, typ)} mask = 0b" + "".join( "1" if typ.bitname_is_valid(bitname) else "0" for bitname in reversed(typ.bits) ) + ";\n" ) else: ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_verenum(idprefix, 'NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += ( f"\t\t[{c_verenum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( "1" if typ.bitname_is_valid(bitname, ver) else "0" for bitname in reversed(typ.bits) ) + ",\n" ) ret += "\t};\n" ret += f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' ret += "\t\t val & ~mask);\n" ret += "\treturn false;\n" case Struct(): if len(typ.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue if typ.name == "stat": # SPECIAL ret += f"\n\tuint32_t size_offset = ctx->net_offset;" prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " struct_versions = typ.members[0].ver prefix = prefix0 prev_size: int | None = None for member in typ.members: ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt is not None: assert prev_size ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: ret += f"validate_{member.typ.name}(ctx)" if member.max: assert member.static_size ret += f"\n{prefix1}(decode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]) > ({c_typename(idprefix, member.typ)})({member.max})" ret += f'\n{prefix2}\t? lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%"PRIu{member.static_size*8}" > %"PRIu{member.static_size*8}")",' ret += f"\n{prefix2}\t\tdecode_u{member.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{member.static_size}]), ({c_typename(idprefix, member.typ)})({member.max}))" ret += f"\n{prefix2}\t: false)" if member.ver != struct_versions: ret += " )" prefix = prefix1 prev_size = member.static_size if typ.name == "stat": # SPECIAL assert typ.members[0].static_size ret += f"\n{prefix1}((uint32_t)decode_u{typ.members[0].static_size*8}le(&ctx->net_bytes[size_offset]) != ctx->net_offset - size_offset" ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_EBADMSG, "stat size does not match stat contents")' ret += f"\n{prefix2}\t: false)" ret += ";\n" ret += "}\n" # unmarshal_* ############################################################## ret += """ /* unmarshal_* ****************************************************************/ static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 1; } static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; } static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; } static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; } """ for typ in typs: inline = ( " FLATTEN" if (isinstance(typ, Struct) and typ.msgid is not None) else " ALWAYS_INLINE" ) argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" match typ: case Bitfield(): ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" case Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" if typ.members: struct_versions = typ.members[0].ver for member in typ.members: if typ.name == "stat" and member.name == "stat_size": # SPECIAL ret += f"\tctx->net_offset += {member.static_size};\n" continue ret += "\t" prefix = "\t" if member.ver != struct_versions: ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " prefix = "\t\t" if member.cnt: if member.ver != struct_versions: ret += f"{{\n{prefix}" ret += f"out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" if typ.name in ["d", "s"]: # SPECIAL ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" else: ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.ver != struct_versions: ret += "\t}\n" else: ret += ( f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ) if typ.name == "s": # SPECIAL ret += "\tctx->extra++;\n" ret += "\tout->utf8[out->len] = '\\0';\n" ret += "}\n" # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), ctx->ctx->max_msg_size); return true; } static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); ctx->net_bytes[ctx->net_offset] = *val; ctx->net_offset += 1; return false; } static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) return _marshal_too_large(ctx); encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 2; return false; } static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) return true; encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 4; return false; } static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) return true; encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); ctx->net_offset += 8; return false; } """ for typ in typs: inline = ( " FLATTEN" if (isinstance(typ, Struct) and typ.msgid is not None) else " ALWAYS_INLINE" ) argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" match typ: case Bitfield(): ret += "\n" ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" case Struct(): if len(typ.members) == 0: ret += "\n\treturn false;\n" ret += "}\n" continue if typ.name == "stat": # SPECIAL ret += "\n\tuint32_t size_offset = ctx->net_offset;" prefix0 = "\treturn " prefix1 = "\t || " prefix2 = "\t " struct_versions = typ.members[0].ver prefix = prefix0 for member in typ.members: if typ.name == "stat" and member.name == "stat_size": # SPECIAL: assert member.static_size ret += f"\n{prefix }(ctx->net_offset + {member.static_size} > ctx->ctx->max_msg_size" ret += f"\n{prefix2}\t? _marshal_too_large(ctx)" ret += f"\n{prefix2}\t: ({{ ctx->net_offset += {member.static_size}; false; }}))" prefix = prefix1 continue ret += f"\n{prefix}" if member.ver != struct_versions: ret += "( " + c_vercond(idprefix, member.ver) + " && " if member.cnt: ret += "({" ret += f"\n{prefix2}\tbool err = false;" ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" if typ.name in ["d", "s"]: # SPECIAL ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);" else: ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" ret += f"\n{prefix2}\terr;" ret += f"\n{prefix2}}})" else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" if member.ver != struct_versions: ret += " )" prefix = prefix1 if typ.name == "stat": # SPECIAL assert typ.members[0].static_size ret += f"\n{prefix1}((ctx->net_offset - size_offset > UINT16_MAX)" ret += f'\n{prefix2}\t? lib9p_error(ctx->ctx, LINUX_ERANGE, "stat object too large")' ret += f"\n{prefix2}\t: ({{ encode_u{typ.members[0].static_size*8}le((uint{typ.members[0].static_size*8}_t)(ctx->net_offset - size_offset), &ctx->net_bytes[size_offset]);" ret += f"\n{prefix2} false; }}))" ret += ";\n" ret += "}\n" # vtables ################################################################## ret += f""" /* vtables ********************************************************************/ #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ .basesize = sizeof(struct {idprefix}msg_##typ), \\ .validate = validate_##typ, \\ .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ .marshal = (_marshal_fn_t)marshal_##typ, \\ }} struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{ """ ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n" for msg in just_structs_msg(typs): if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" for ver in sorted(versions): ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" for msg in just_structs_msg(typs): if ver not in msg.msgver: continue ret += f"\t\t_MSG({msg.name}),\n" ret += "\t}},\n" ret += "};\n" ############################################################################ return ret ################################################################################ class Parser: cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {} def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[Bitfield | Struct]]: ret_versions: set[str] = set() ret_typs: dict[str, Bitfield | Struct] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for typ in typs: if typ.name in ret_typs: if typ != ret_typs[typ.name]: raise ValueError(f"duplicate type name {repr(typ.name)}") else: ret_typs[typ.name] = typ return ret_versions, list(ret_typs.values()) if __name__ == "__main__": import sys if len(sys.argv) < 2: raise ValueError("requires at least 1 .txt filename") parser = Parser() for txtname in sys.argv[1:]: parser.parse_file(txtname) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: fh.write(gen_h("lib9p_", versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: fh.write(gen_c("lib9p_", versions, typs))