#!/usr/bin/env python # lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files # defining 9P protocol variants. # # Copyright (C) 2024 Luke T. Shumaker # SPDX-Licence-Identifier: AGPL-3.0-or-later import enum import os.path import re from abc import ABC, abstractmethod from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # Types ######################################################################## class Primitive(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def in_versions(self) -> set[str]: return set() @property def name(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value class Number: name: str in_versions: set[str] prim: Primitive def __init__(self) -> None: self.in_versions = set() @property def static_size(self) -> int: return self.prim.static_size class BitfieldVal: name: str in_versions: set[str] val: str def __init__(self) -> None: self.in_versions = set() class Bitfield: name: str in_versions: set[str] prim: Primitive bits: list[str] # bitnames names: dict[str, BitfieldVal] # bits *and* aliases def __init__(self) -> None: self.in_versions = set() self.names = {} @property def static_size(self) -> int: return self.prim.static_size def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool: """Return whether the given bit is valid in the given protocol version. """ bitname = self.bits[bit] if isinstance(bit, int) else bit assert bitname in self.bits if not bitname: return False if bitname.startswith("_"): return False if ver and (ver not in self.names[bitname].in_versions): return False return True class ExprLit: val: int def __init__(self, val: int) -> None: self.val = val class ExprSym: name: str def __init__(self, name: str) -> None: self.name = name class ExprOp: op: Literal["-", "+"] def __init__(self, op: Literal["-", "+"]) -> None: self.op = op class Expr: tokens: list[ExprLit | ExprSym | ExprOp] def __init__(self) -> None: self.tokens = [] def __bool__(self) -> bool: return len(self.tokens) > 0 class StructMember: # from left-to-right when parsing cnt: str | None = None name: str typ: 'Type' max: Expr val: Expr in_versions: set[str] @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size class Struct: name: str in_versions: set[str] members: list[StructMember] def __init__(self) -> None: self.in_versions = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: msize = member.static_size if msize is None: return None size += msize return size class Message(Struct): @property def msgid(self) -> int: assert len(self.members) >= 3 assert self.members[1].name == "typ" assert self.members[1].static_size == 1 assert self.members[1].val assert len(self.members[1].val.tokens) == 1 assert isinstance(self.members[1].val.tokens[0], ExprLit) return self.members[1].val.tokens[0].val Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message #type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13 T = TypeVar("T", Number, Bitfield, Struct, Message) # Parse *.9p ################################################################### re_priname = "(?:1|2|4|8)" # primitive names re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names re_impname = r"(?:\*|" + re_symname + ")" # names we can import re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)" re_bitspec_bit = f"(?P[0-9]+)\\s*=\\s*(?P{re_symname})" re_bitspec_alias = f"(?P{re_symname})\\s*=\\s*(?P\\S+)" re_memberspec = f"(?:(?P{re_symname})\\*\\()?(?P{re_symname})\\[(?P{re_memtype})(?:,max=(?P{re_expr})|,val=(?P{re_expr}))*\\]\\)?" def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: spec = spec.strip() bit: int | None val: BitfieldVal if m := re.fullmatch(re_bitspec_bit, spec): bit = int(m.group("bit")) name = m.group("name") val = BitfieldVal() val.name = name val.val = f"1<<{bit}" val.in_versions.add(ver) if bit < 0 or bit >= len(bf.bits): raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") if bf.bits[bit]: raise ValueError(f"{bf.name}: bit {bit} already assigned") bf.bits[bit] = val.name elif m := re.fullmatch(re_bitspec_alias, spec): name = m.group("name") valstr = m.group("val") val = BitfieldVal() val.name = name val.val = valstr val.in_versions.add(ver) else: raise SyntaxError(f"invalid bitfield spec {repr(spec)}") if val.name in bf.names: raise ValueError(f"{bf.name}: name {val.name} already assigned") bf.names[val.name] = val def parse_expr(expr: str) -> Expr: assert re.fullmatch(re_expr, expr) ret = Expr() for tok in re.split("([-+])", expr): if tok == "-" or tok == "+": ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] elif re.fullmatch("[0-9]+", tok): ret.tokens += [ExprLit(int(tok))] else: ret.tokens += [ExprSym(tok)] return ret def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") member = StructMember() member.in_versions = {ver} member.name = m.group("name") if any(x.name == member.name for x in struct.members): raise ValueError(f"duplicate member name {repr(member.name)}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group(2))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if len(struct.members) == 0 or struct.members[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") if not isinstance(struct.members[-1].typ, Primitive): raise ValueError(f"list count must be an integer type: {repr(cnt)}") member.cnt = cnt if maxstr := m.group("max"): if (not isinstance(member.typ, Primitive)) or member.cnt: raise ValueError("',max=' may only be specified on a non-repeated atom") member.max = parse_expr(maxstr) else: member.max = Expr() if valstr := m.group("val"): if (not isinstance(member.typ, Primitive)) or member.cnt: raise ValueError("',val=' may only be specified on a non-repeated atom") member.val = parse_expr(valstr) else: member.val = Expr() struct.members += [member] def re_string(grpname: str) -> str: return f'"(?P<{grpname}>[^"]*)"' re_line_version = f"version\\s+{re_string('version')}" re_line_import = f"from\\s+(?P\\S+)\\s+import\\s+(?P{re_impname}(?:\\s*,\\s*{re_impname})*)" re_line_num = f"num\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield = f"bitfield\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield_ = ( f"bitfield\\s+(?P{re_symname})\\s*\\+=\\s*{re_string('member')}" ) re_line_struct = ( f"struct\\s+(?P{re_symname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_msg = ( f"msg\\s+(?P{re_msgname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg def parse_file( filename: str, get_include: Callable[[str], tuple[str, list[Type]]] ) -> tuple[str, list[Type]]: version: str | None = None env: dict[str, Type] = { "1": Primitive.u8, "2": Primitive.u16, "4": Primitive.u32, "8": Primitive.u64, } def get_type(name: str, tc: type[T]) -> T: nonlocal env if name not in env: raise NameError(f"Unknown type {repr(name)}") ret = env[name] if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") return ret with open(filename, "r") as fh: prev: Type | None = None for line in fh: line = line.split("#", 1)[0].rstrip() if not line: continue if m := re.fullmatch(re_line_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") continue if not version: raise SyntaxError("must have exactly 1 version line") if m := re.fullmatch(re_line_import, line): other_version, other_typs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() for typ in other_typs: if typ.name == symname or symname == "*": match typ: case Primitive(): pass case Number(): typ.in_versions.add(version) case Bitfield(): typ.in_versions.add(version) for val in typ.names.values(): if other_version in val.in_versions: val.in_versions.add(version) case Struct(): # and Message() typ.in_versions.add(version) for member in typ.members: if other_version in member.in_versions: member.in_versions.add(version) env[typ.name] = typ elif m := re.fullmatch(re_line_num, line): num = Number() num.name = m.group("name") num.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) num.prim = prim env[num.name] = num prev = num elif m := re.fullmatch(re_line_bitfield, line): bf = Bitfield() bf.name = m.group("name") bf.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) bf.prim = prim bf.bits = (prim.static_size * 8) * [""] env[bf.name] = bf prev = bf elif m := re.fullmatch(re_line_bitfield_, line): bf = get_type(m.group("name"), Bitfield) parse_bitspec(version, bf, m.group("member")) prev = bf elif m := re.fullmatch(re_line_struct, line): match m.group("op"): case "=": struct = Struct() struct.name = m.group("name") struct.in_versions.add(version) struct.members = [] parse_members(version, env, struct, m.group("members")) env[struct.name] = struct prev = struct case "+=": struct = get_type(m.group("name"), Struct) parse_members(version, env, struct, m.group("members")) prev = struct elif m := re.fullmatch(re_line_msg, line): match m.group("op"): case "=": msg = Message() msg.name = m.group("name") msg.in_versions.add(version) msg.members = [] parse_members(version, env, msg, m.group("members")) env[msg.name] = msg prev = msg case "+=": msg = get_type(m.group("name"), Message) parse_members(version, env, msg, m.group("members")) prev = msg elif m := re.fullmatch(re_line_cont, line): match prev: case Bitfield(): parse_bitspec(version, prev, m.group("specs")) case Struct(): # and Message() parse_members(version, env, prev, m.group("specs")) case _: raise SyntaxError( "continuation line must come after a bitfield, struct, or msg line" ) else: raise SyntaxError(f"invalid line {repr(line)}") if not version: raise SyntaxError("must have exactly 1 version line") typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] for typ in [typ for typ in typs if isinstance(typ, Struct)]: valid_syms = ["end", *["&" + m.name for m in typ.members]] for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, ExprSym) and tok.name not in valid_syms: raise ValueError( f"{typ.name}.{member.name}: invalid sym: {tok.name}" ) return version, typs # Generate C ################################################################### def c_ver_enum(idprefix: str, ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" def c_ver_ifdef(versions: set[str]) -> str: return " || ".join( f"defined(CONFIG_9P_ENABLE_{v.replace('.', '_')})" for v in sorted(versions) ) def c_ver_cond(idprefix: str, versions: set[str]) -> str: if len(versions) == 1: return f"(ctx->ctx->version=={c_ver_enum(idprefix, next(v for v in versions))})" return ( "( " + (" || ".join(c_ver_cond(idprefix, {v}) for v in sorted(versions))) + " )" ) def c_typename(idprefix: str, typ: Type) -> str: match typ: case Primitive(): return f"uint{typ.value*8}_t" case Number(): return f"{idprefix}{typ.name}_t" case Bitfield(): return f"{idprefix}{typ.name}_t" case Message(): return f"struct {idprefix}msg_{typ.name}" case Struct(): return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") _ifdef_stack: list[str | None] = [] def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed global _ifdef_stack while len(_ifdef_stack) < n: _ifdef_stack += [None] # Set some variables parentval: str | None = None for x in _ifdef_stack[:-1]: if x is not None: parentval = x oldval = _ifdef_stack[-1] newval: str | None = _newval if newval == parentval: newval = None # Put newval on the stack. _ifdef_stack[-1] = newval # Build output. ret = "" if newval != oldval: if oldval is not None: ret += f"#endif /* {oldval} */\n" if newval is not None: ret += f"#if {newval}\n" return ret def ifdef_pop(n: int) -> str: global _ifdef_stack ret = "" while len(_ifdef_stack) > n: if _ifdef_stack[-1] is not None: ret += f"#endif /* {_ifdef_stack[-1]} */\n" _ifdef_stack = _ifdef_stack[:-1] return ret def gen_h(idprefix: str, versions: set[str], typs: list[Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ # error Do not include directly; include instead #endif #include /* for uint{{n}}_t types */ """ ret += f""" /* versions *******************************************************************/ enum {idprefix}version {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t{c_ver_enum(idprefix, ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += ifdef_pop(0) ret += f"\t{c_ver_enum(idprefix, 'NUM')},\n" ret += "};\n" ret += "\n" ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ /* non-message types **********************************************************/ """ for typ in [typ for typ in typs if not isinstance(typ, Message)]: ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) match typ: case Number(): ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" case Bitfield(): ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" names = [ *reversed( [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))] ), "", *[k for k in typ.names if k not in typ.bits], ] namewidth = max(len(name) for name in names) ret += "\n" for name in names: if name == "": ret += "\n" elif name.startswith(" "): ret += ifdef_push(2, c_ver_ifdef(typ.in_versions)) sp = ' '*(len('# define ')+len(idprefix)+len(typ.name)+1+namewidth+2-len("/* unused")) ret += f"/* unused{sp}(({c_typename(idprefix, typ)})(1<<{name[1:]})) */\n" else: ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions)) if name.startswith("_"): c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}" else: c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}" sp1 = ' ' if _ifdef_stack[-1] else '' sp2 = ' ' if _ifdef_stack[-1] else ' ' sp3 = ' '*(2+namewidth-len(name)) ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(idprefix, typ)})({typ.names[name].val}))\n" ret += ifdef_pop(1) case Struct(): typewidth = max(len(c_typename(idprefix, m.typ)) for m in typ.members) ret += c_typename(idprefix, typ) + " {\n" for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) c_type = c_typename(idprefix, member.typ) if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL c_type = "char" ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) ret += """ /* messages *******************************************************************/ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message)) for msg in [msg for msg in typs if isinstance(msg, Message)]: ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" ret += ifdef_pop(0) ret += "};\n" ret += "\n" ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" for msg in [msg for msg in typs if isinstance(msg, Message)]: ret += "\n" ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) ret += c_typename(idprefix, msg) + " {" if not msg.members: ret += "};\n" continue ret += "\n" typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) for member in msg.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) return ret def c_expr(expr: Expr) -> str: ret: list[str] = [] for tok in expr.tokens: match tok: case ExprOp(): ret += [tok.op] case ExprLit(): ret += [str(tok.val)] case ExprSym(name="end"): ret += ["ctx->net_offset"] case ExprSym(): ret += [f"_{tok.name[1:]}_offset"] return " ".join(ret) def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include "internal.h" """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"UNUSED({arg})" # strings ################################################################## ret += f""" /* strings ********************************************************************/ static const char *version_strs[{c_ver_enum(idprefix, 'NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) ret += f'\t[{c_ver_enum(idprefix, ver)}] = "{ver}",\n' ret += ifdef_pop(0) ret += "};\n" ret += f""" const char *{idprefix}version_str(enum {idprefix}version ver) {{ assert(0 <= ver && ver < {c_ver_enum(idprefix, 'NUM')}); return version_strs[ver]; }} static const char *msg_type_strs[0x100] = {{ """ id2name: dict[int, str] = {} for msg in [msg for msg in typs if isinstance(msg, Message)]: id2name[msg.msgid] = msg.name for n in range(0, 0x100): ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) ret += "};\n" ret += f""" const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ assert(0 <= typ && typ <= 0xFF); return msg_type_strs[typ]; }} """ # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) /* If needed-net-size overflowed uint32_t, then * there's no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); if (ctx->net_offset > ctx->net_size) return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) /* If needed-host-size overflowed size_t, then there's * no way that actual-net-size will live up to * that. */ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); return false; } static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { for (size_t i = 0; i < cnt; i++) if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) return true; return false; } #define validate_1(ctx) _validate_size_net(ctx, 1) #define validate_2(ctx) _validate_size_net(ctx, 2) #define validate_4(ctx) _validate_size_net(ctx, 4) #define validate_8(ctx) _validate_size_net(ctx, 8) """ for typ in typs: inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" argfn = unused if (isinstance(typ, Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" if typ.name == "d": # SPECIAL # Optimize... maybe the compiler could figure out to do # this, but let's make it obvious. ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_4(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" ret += "}\n" continue if typ.name == "s": # SPECIAL # Add an extra nul-byte on the host, and validate UTF-8 # (also, similar optimization to "d"). ret += "\tuint32_t base_offset = ctx->net_offset;\n" ret += "\tif (validate_2(ctx))\n" ret += "\t\treturn true;\n" ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" ret += "\t\treturn true;\n" ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' ret += "\treturn false;\n" ret += "}\n" continue match typ: case Number(): ret += f"\treturn validate_{typ.prim.name}(ctx);\n" case Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_ver_enum(idprefix, 'NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += ifdef_push(2, c_ver_ifdef({ver})) ret += ( f"\t\t[{c_ver_enum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( "1" if typ.bit_is_valid(bitname, ver) else "0" for bitname in reversed(typ.bits) ) + ",\n" ) ret += ifdef_pop(1) ret += "\t};\n" ret += ( f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" ) ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" ret += "\t\treturn lib9p_errorf(ctx->ctx,\n" ret += f'\t\t LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' ret += "\t\t val & ~mask);\n" ret += "\treturn false;\n" case Struct(): # and Message() if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue # Pass 1 for member in typ.members: if member.max or member.val: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(idprefix, member.typ)} {member.name};\n" ret += ifdef_pop(1) # Pass 2 mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, ExprSym) and tok.name.startswith("&"): if tok.name[1:] not in mark_offset: ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" mark_offset.add(tok.name[1:]) # Pass 3 ret += "\treturn false\n" prev_size: int | None = None for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(idprefix, member.in_versions) + " && " if member.cnt is not None: assert prev_size ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" else: if member.max or member.val: ret += "(" if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" if member.name in mark_offset: ret += "; })" if member.max or member.val: bytes = member.static_size assert bytes bits = bytes * 8 ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" prev_size = member.static_size # Pass 4 for member in typ.members: if member.max: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\n\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n" ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n' if member.val: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\n\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n" ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += ifdef_pop(0) # # unmarshal_* ############################################################## # ret += """ # /* unmarshal_* ****************************************************************/ # static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { # *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 1; # } # static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { # *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 2; # } # static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { # *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 4; # } # static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { # *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 8; # } # """ # for typ in typs: # inline = ( # " FLATTEN" # if (isinstance(typ, Struct) and typ.msgid is not None) # else " ALWAYS_INLINE" # ) # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used # ret += "\n" # ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" # match typ: # case Bitfield(): # ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" # case Struct(): # ret += "\tmemset(out, 0, sizeof(*out));\n" # if typ.members: # struct_versions = typ.members[0].ver # for member in typ.members: # if member.valexpr: # ret += f"\tctx->net_offset += {member.static_size};\n" # continue # ret += "\t" # prefix = "\t" # if member.ver != struct_versions: # ret += "if ( " + c_ver_cond(idprefix, member.ver) + " ) " # prefix = "\t\t" # if member.cnt: # if member.ver != struct_versions: # ret += f"{{\n{prefix}" # ret += f"out->{member.name} = ctx->extra;\n" # ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" # ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" # if typ.name in ["d", "s"]: # SPECIAL # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" # else: # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" # if member.ver != struct_versions: # ret += "\t}\n" # else: # ret += ( # f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" # ) # if typ.name == "s": # SPECIAL # ret += "\tctx->extra++;\n" # ret += "\tout->utf8[out->len] = '\\0';\n" # ret += "}\n" # # marshal_* ################################################################ # ret += """ # /* marshal_* ******************************************************************/ # static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { # lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", # (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", # ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), # ctx->ctx->max_msg_size); # return true; # } # static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { # if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) # return _marshal_too_large(ctx); # ctx->net_bytes[ctx->net_offset] = *val; # ctx->net_offset += 1; # return false; # } # static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { # if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) # return _marshal_too_large(ctx); # encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 2; # return false; # } # static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { # if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) # return true; # encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 4; # return false; # } # static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { # if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) # return true; # encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); # ctx->net_offset += 8; # return false; # } # """ # for typ in typs: # inline = ( # " FLATTEN" # if (isinstance(typ, Struct) and typ.msgid is not None) # else " ALWAYS_INLINE" # ) # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used # ret += "\n" # ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" # match typ: # case Bitfield(): # ret += "\n" # ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" # case Struct(): # if len(typ.members) == 0: # ret += "\n\treturn false;\n" # ret += "}\n" # continue # mark_offset = set() # for member in typ.members: # if member.valexpr: # if member.name not in mark_offset: # ret += f"\n\tuint32_t _{member.name}_offset;" # mark_offset.add(member.name) # for tok in member.valexpr: # if ( # isinstance(tok, ExprVal) # and tok.name.startswith("&") # and tok.name[1:] not in mark_offset # ): # ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" # mark_offset.add(tok.name[1:]) # prefix0 = "\treturn " # prefix1 = "\t || " # prefix2 = "\t " # struct_versions = typ.members[0].ver # prefix = prefix0 # for member in typ.members: # ret += f"\n{prefix}" # if member.ver != struct_versions: # ret += "( " + c_ver_cond(idprefix, member.ver) + " && " # if member.name in mark_offset: # ret += f"({{ _{member.name}_offset = ctx->net_offset; " # if member.cnt: # ret += "({" # ret += f"\n{prefix2}\tbool err = false;" # ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" # if typ.name in ["d", "s"]: # SPECIAL # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);" # else: # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" # ret += f"\n{prefix2}\terr;" # ret += f"\n{prefix2}}})" # elif member.valexpr: # assert member.static_size # ret += ( # f"({{ ctx->net_offset += {member.static_size}; false; }})" # ) # else: # ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" # if member.name in mark_offset: # ret += "; })" # if member.ver != struct_versions: # ret += " )" # prefix = prefix1 # for member in typ.members: # if member.valexpr: # assert member.static_size # ret += f"\n{prefix}" # ret += f"({{ encode_u{member.static_size*8}le(" # for tok in member.valexpr: # match tok: # case ExprOp(): # ret += f" {tok.op}" # case ExprVal(name="end"): # ret += " ctx->net_offset" # case ExprVal(): # ret += f" _{tok.name[1:]}_offset" # ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" # ret += ";\n" # ret += "}\n" # # vtables ################################################################## # ret += f""" # /* vtables ********************************************************************/ # #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ # .basesize = sizeof(struct {idprefix}msg_##typ), \\ # .validate = validate_##typ, \\ # .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ # .marshal = (_marshal_fn_t)marshal_##typ, \\ # }} # struct _vtable_version _{idprefix}vtables[{c_ver_enum(idprefix, 'NUM')}] = {{ # """ # ret += f"\t[{c_ver_enum(idprefix, 'unknown')}] = {{ .msgs = {{\n" # for msg in just_structs_msg(typs): # if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL # ret += f"\t\t_MSG({msg.name}),\n" # ret += "\t}},\n" # for ver in sorted(versions): # ret += f"\t[{c_ver_enum(idprefix, ver)}] = {{ .msgs = {{\n" # for msg in just_structs_msg(typs): # if ver not in msg.msgver: # continue # ret += f"\t\t_MSG({msg.name}),\n" # ret += "\t}},\n" # ret += "};\n" ############################################################################ return ret ################################################################################ class Parser: cache: dict[str, tuple[str, list[Type]]] = {} def parse_file(self, filename: str) -> tuple[str, list[Type]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[Type]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[Type]]: ret_versions: set[str] = set() ret_typs: dict[str, Type] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for typ in typs: if typ.name in ret_typs: if typ != ret_typs[typ.name]: raise ValueError(f"duplicate type name {repr(typ.name)}") else: ret_typs[typ.name] = typ return ret_versions, list(ret_typs.values()) if __name__ == "__main__": import sys if len(sys.argv) < 2: raise ValueError("requires at least 1 .9p filename") parser = Parser() for txtname in sys.argv[1:]: parser.parse_file(txtname) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: fh.write(gen_h("lib9p_", versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: fh.write(gen_c("lib9p_", versions, typs))