diff options
Diffstat (limited to 'lib9p/9p.gen')
-rwxr-xr-x | lib9p/9p.gen | 911 |
1 files changed, 911 insertions, 0 deletions
diff --git a/lib9p/9p.gen b/lib9p/9p.gen new file mode 100755 index 0000000..f974dd1 --- /dev/null +++ b/lib9p/9p.gen @@ -0,0 +1,911 @@ +#!/usr/bin/env python +# lib9p/9p.gen - Generate C marshalers/unmarshalers for .txt files +# defining 9P protocol variants. +# +# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-Licence-Identifier: AGPL-3.0-or-later + +import enum +import os.path +import re +from typing import Callable, Sequence + +# This strives to be "general-purpose" in that it just acts on the +# *.txt inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + +# Parse *.txt ################################################################## + + +class Atom(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 4 + u64 = 8 + + @property + def name(self) -> str: + return str(self.value) + + @property + def static_size(self) -> int: + return self.value + + +class BitfieldVal: + name: str + val: str + ver: set[str] + + def __init__(self) -> None: + self.ver = set() + + +class Bitfield: + name: str + bits: list[str] + names: dict[str, BitfieldVal] + + @property + def static_size(self) -> int: + return int((len(self.bits) + 7) / 8) + + def bitname_is_valid(self, bitname: str, ver: str | None = None) -> bool: + assert bitname in self.bits + if not bitname: + return False + if bitname.startswith("_"): + return False + if ver and (ver not in self.names[bitname].ver): + return False + return True + + +# `msgid/structname = "member1 member2..."` +# `structname = "member1 member2..."` +# `structname += "member1 member2..."` +class Struct: + msgid: int | None = None + msgver: set[str] + name: str + members: list["Member"] + + def __init__(self) -> None: + self.msgver = set() + + @property + def static_size(self) -> int | None: + size = 0 + for member in self.members: + msize = member.static_size + if msize is None: + return None + size += msize + return size + + +# `cnt*(name[typ])` +# the `cnt*(...)` wrapper is optional +class Member: + cnt: str | None = None + name: str + typ: Atom | Bitfield | Struct + ver: set[str] + + @property + def static_size(self) -> int | None: + if self.cnt: + return None + return self.typ.static_size + + +re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" +re_memberspec = ( + f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>.*)\\]\\)?" +) + + +def parse_members( + ver: str, + env: dict[str, Atom | Bitfield | Struct], + existing: list[Member], + specs: str, +) -> list[Member]: + ret = existing + for spec in specs.split(): + m = re.fullmatch(re_memberspec, spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + + member = Member() + member.ver = {ver} + + member.name = m.group("name") + if any(x.name == member.name for x in ret): + raise ValueError(f"duplicate member name {repr(member.name)}") + + if m.group("typ") not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + member.typ = env[m.group("typ")] + + if cnt := m.group("cnt"): + if len(ret) == 0 or ret[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(ret[-1].typ, Atom): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + member.cnt = cnt + + ret += [member] + return ret + + +re_version = r'version\s+"(?P<version>[^"]+)"' +re_import = r"from\s+(?P<file>\S+)\s+import\s+(?P<syms>\S+(?:\s*,\s*\S+)*)\s*" +re_structspec = ( + r'(?:(?P<msgid>[0-9]+)/)?(?P<name>\S+)\s*(?P<op>\+?=)\s*"(?P<members>[^"]*)"' +) +re_structspec_cont = r'\s+"(?P<members>[^"]*)"' +re_bitfieldspec = r"bitfield\s+(?P<name>\S+)\s+(?P<size>[0-9]+)" +re_bitfieldspec_bit = r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<bit>[0-9]+)/(?P<name>\S+)" +re_bitfieldspec_alias = ( + r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<name>\S+)\s*=\s*(?P<val>.*)" +) + + +def parse_file( + filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]] +) -> tuple[str, list[Bitfield | Struct]]: + version: str | None = None + env: dict[str, Atom | Bitfield | Struct] = { + "1": Atom.u8, + "2": Atom.u16, + "4": Atom.u32, + "8": Atom.u64, + } + with open(filename, "r") as fh: + prev: Struct | Bitfield | None = None + for line in fh: + line = line.split("#", 1)[0].rstrip() + if not line: + continue + if m := re.fullmatch(re_version, line): + if version: + raise SyntaxError("must have exactly 1 version line") + version = m.group("version") + elif m := re.fullmatch(re_import, line): + if not version: + raise SyntaxError("must have exactly 1 version line") + other_version, other_typs = get_include(m.group("file")) + for symname in m.group("syms").split(sep=","): + symname = symname.strip() + for typ in other_typs: + if typ.name == symname or symname == "*": + match typ: + case Bitfield(): + for val in typ.names.values(): + if other_version in val.ver: + val.ver.add(version) + case Struct(): + if typ.msgid: + typ.msgver.add(version) + for member in typ.members: + if other_version in member.ver: + member.ver.add(version) + env[typ.name] = typ + elif m := re.fullmatch(re_structspec, line): + if not version: + raise SyntaxError("must have exactly 1 version line") + if m.group("op") == "+=" and m.group("msgid"): + raise SyntaxError("cannot += to a message that is not yet defined") + match m.group("op"): + case "=": + struct = Struct() + if m.group("msgid"): + struct.msgid = int(m.group("msgid")) + struct.msgver.add(version) + struct.name = m.group("name") + struct.members = parse_members( + version, env, [], m.group("members") + ) + env[struct.name] = struct + prev = struct + case "+=": + if m.group("name") not in env: + raise NameError(f"Unknown type {repr(m.group('name'))}") + _struct = env[m.group("name")] + if not isinstance(_struct, Struct): + raise NameError( + f"Type {repr(_struct.name)} is not a struct" + ) + struct = _struct + struct.members = parse_members( + version, env, struct.members, m.group("members") + ) + prev = struct + elif m := re.fullmatch(re_structspec_cont, line): + if not isinstance(prev, Struct): + raise SyntaxError( + "struct-continuation line must come after a struct line" + ) + assert version + prev.members = parse_members( + version, env, prev.members, m.group("members") + ) + elif m := re.fullmatch(re_bitfieldspec, line): + if not version: + raise SyntaxError("must have exactly 1 version line") + bf = Bitfield() + bf.name = m.group("name") + bf.bits = int(m.group("size")) * [""] + bf.names = {} + if len(bf.bits) not in [8, 16, 32, 64]: + raise ValueError(f"Bitfield {repr(bf.name)} has an unusual size") + env[bf.name] = bf + prev = bf + elif m := re.fullmatch(re_bitfieldspec_bit, line): + if m.group("bitfield"): + if m.group("bitfield") not in env: + raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") + _bf = env[m.group("bitfield")] + if not isinstance(_bf, Bitfield): + raise NameError(f"Type {repr(_bf.name)} is not a bitfield") + bf = _bf + prev = bf + else: + if not isinstance(prev, Bitfield): + raise SyntaxError( + "bitfield-continuation line must come after a bitfield line" + ) + bf = prev + bit = int(m.group("bit")) + name = m.group("name") + if bit < 0 or bit >= len(bf.bits): + raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") + if bf.bits[bit]: + raise ValueError(f"{bf.name}: bit {bit} already assigned") + if name in bf.names: + raise ValueError(f"{bf.name}: name {name} already assigned") + + bf.bits[bit] = name + + assert version + val = BitfieldVal() + val.name = name + val.val = f"1<<{bit}" + val.ver.add(version) + bf.names[name] = val + elif m := re.fullmatch(re_bitfieldspec_alias, line): + if m.group("bitfield"): + if m.group("bitfield") not in env: + raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") + _bf = env[m.group("bitfield")] + if not isinstance(_bf, Bitfield): + raise NameError(f"Type {repr(_bf.name)} is not a bitfield") + bf = _bf + prev = bf + else: + if not isinstance(prev, Bitfield): + raise SyntaxError( + "bitfield-continuation line must come after a bitfield line" + ) + bf = prev + name = m.group("name") + valstr = m.group("val") + if name in bf.names: + raise ValueError(f"{bf.name}: name {name} already assigned") + + assert version + val = BitfieldVal() + val.name = name + val.val = valstr + val.ver.add(version) + bf.names[name] = val + else: + raise SyntaxError(f"invalid line {repr(line)}") + if not version: + raise SyntaxError("must have exactly 1 version line") + + typs = [x for x in env.values() if not isinstance(x, Atom)] + return version, typs + + +# Generate C ################################################################### + + +def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str: + match typ: + case Atom(): + return f"uint{typ.value*8}_t" + case Bitfield(): + return f"{idprefix}{typ.name}_t" + case Struct(): + if typ.msgid is not None: + return f"struct {idprefix}msg_{typ.name}" + return f"struct {idprefix}{typ.name}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def c_verenum(idprefix: str, ver: str) -> str: + return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" + + +def c_vercomment(versions: set[str]) -> str | None: + if "9P2000" in versions: + return None + return "/* " + (", ".join(sorted(versions))) + " */" + + +def c_vercond(idprefix: str, versions: set[str]) -> str: + if len(versions) == 1: + return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" + return ( + "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" + ) + + +def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]: + return list(typ for typ in typs if isinstance(typ, Struct)) + + +def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: + return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None) + + +def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: + return list( + typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None + ) + + +def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]: + return list(typ for typ in typs if isinstance(typ, Bitfield)) + + +def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#ifndef _LIB9P_9P_H_ +# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead +#endif + +#include <stdint.h> /* for uint{{n}}_t types */ +""" + + ret += f""" +/* versions *******************************************************************/ + +enum {idprefix}version {{ +""" + fullversions = ["unknown = 0", *sorted(versions)] + verwidth = max(len(v) for v in fullversions) + for ver in fullversions: + ret += f"\t{c_verenum(idprefix, ver)}," + ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' + ret += f"\t{c_verenum(idprefix, 'NUM')},\n" + ret += "};\n" + ret += "\n" + ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" + + ret += """ +/* non-message types **********************************************************/ +""" + for bf in just_bitfields(typs): + ret += "\n" + ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n" + names = [ + *reversed([bf.bits[n] or f"_UNUSED_{n}" for n in range(0, len(bf.bits))]), + *[k for k in bf.names if k not in bf.bits], + ] + namewidth = max(len(name) for name in names) + + for name in names: + if name.startswith("_"): + cname = f"_{idprefix.upper()}{bf.name.upper()}_{name[1:]}" + else: + cname = f"{idprefix.upper()}{bf.name.upper()}_{name}" + if name in bf.names: + val = bf.names[name].val + else: + assert name.startswith("_UNUSED_") + val = f"1<<{name[len('_UNUSED_'):]}" + ret += f"#define {cname}{' '*(namewidth-len(name))} (({c_typename(idprefix, bf)})({val}))" + if (name in bf.names) and (comment := c_vercomment(bf.names[name].ver)): + ret += " " + comment + ret += "\n" + + for struct in just_structs_nonmsg(typs): + all_the_same = len(struct.members) == 0 or all( + m.ver == struct.members[0].ver for m in struct.members + ) + typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) + if not all_the_same: + namewidth = max(len(m.name) for m in struct.members) + + ret += "\n" + ret += c_typename(idprefix, struct) + " {\n" + for member in struct.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" + if (not all_the_same) and (comment := c_vercomment(member.ver)): + ret += (" " * (namewidth - len(member.name))) + " " + comment + ret += "\n" + ret += "};\n" + + ret += """ +/* messages *******************************************************************/ + +""" + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.name) for msg in just_structs_msg(typs)) + for msg in just_structs_msg(typs): + ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," + if comment := c_vercomment(msg.msgver): + ret += " " + comment + ret += "\n" + ret += "};\n" + ret += "\n" + ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" + + for msg in just_structs_msg(typs): + ret += "\n" + if comment := c_vercomment(msg.msgver): + ret += comment + "\n" + ret += c_typename(idprefix, msg) + " {" + if not msg.members: + ret += "};\n" + continue + ret += "\n" + + all_the_same = len(msg.members) == 0 or all( + m.ver == msg.members[0].ver for m in msg.members + ) + typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) + if not all_the_same: + namewidth = max(len(m.name) for m in msg.members) + + for member in msg.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" + if (not all_the_same) and (comment := c_vercomment(member.ver)): + ret += (" " * (namewidth - len(member.name))) + " " + comment + ret += "\n" + ret += "};\n" + + return ret + + +def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#include <assert.h> +#include <stdbool.h> +#include <stddef.h> /* for size_t */ +#include <inttypes.h> /* for PRI* macros */ +#include <string.h> /* for memset() */ + +#include <lib9p/9p.h> + +#include "internal.h" +""" + + def used(arg: str) -> str: + return arg + + def unused(arg: str) -> str: + return f"UNUSED({arg})" + + # strings ################################################################## + ret += f""" +/* strings ********************************************************************/ + +static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{ +""" + for ver in ["unknown", *sorted(versions)]: + ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n' + ret += "};\n" + ret += f""" +const char *{idprefix}version_str(enum {idprefix}version ver) {{ + assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')}); + return version_strs[ver]; +}} + +static const char *msg_type_strs[0x100] = {{ +""" + id2name: dict[int, str] = {} + for msg in just_structs_msg(typs): + assert msg.msgid + id2name[msg.msgid] = msg.name + for n in range(0, 0x100): + ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) + ret += "};\n" + ret += f""" +const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ + assert(0 <= typ && typ <= 0xFF); + return msg_type_strs[typ]; +}} +""" + + # validate_* ############################################################### + ret += """ +/* validate_* *****************************************************************/ + +static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { + if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) + /* If needed-net-size overflowed uint32_t, then + * there's no way that actual-net-size will live up to + * that. */ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + if (ctx->net_offset > ctx->net_size) + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { + if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) + /* If needed-host-size overflowed size_t, then there's + * no way that actual-net-size will live up to + * that. */ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, + size_t cnt, size_t max, + _validate_fn_t item_fn, size_t item_host_size) { + if (max && cnt > max) + return lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%zu > %zu)", + cnt, max); + for (size_t i = 0; i < cnt; i++) + if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) + return true; + return false; +} + +#define validate_1(ctx) _validate_size_net(ctx, 1) +#define validate_2(ctx) _validate_size_net(ctx, 2) +#define validate_4(ctx) _validate_size_net(ctx, 4) +#define validate_8(ctx) _validate_size_net(ctx, 8) +""" + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{" + + if typ.name == "d": # SPECIAL + # Optimize... maybe the compiler could figure out to do + # this, but let's make it obvious. + ret += "\n" + ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (validate_4(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" + ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" + ret += "}\n" + continue + if typ.name == "s": # SPECIAL + # Add an extra nul-byte on the host, and validate UTF-8 + # (also, similar optimization to "d"). + ret += "\n" + ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (validate_2(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" + ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" + ret += "\t\treturn true;\n" + ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" + ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' + ret += "\treturn false;\n" + ret += "}\n" + continue + + match typ: + case Bitfield(): + ret += "\n" + all_the_same = all( + val.ver == [*typ.names.values()][0].ver + for val in typ.names.values() + ) + if ( + all_the_same + and (len(typ.bits) == typ.static_size * 8) + and all(typ.bitname_is_valid(bitname) for bitname in typ.bits) + ): + ret += f"\treturn validate_{typ.static_size}(ctx));\n" + else: + ret += f"\t if (validate_{typ.static_size}(ctx))\n" + ret += "\t\treturn true;\n" + if all_the_same: + ret += ( + f"\tstatic const {c_typename(idprefix, typ)} mask = 0b" + + "".join( + "1" if typ.bitname_is_valid(bitname) else "0" + for bitname in reversed(typ.bits) + ) + + ";\n" + ) + else: + ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_verenum(idprefix, 'NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + ret += ( + f"\t\t[{c_verenum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + "1" if typ.bitname_is_valid(bitname, ver) else "0" + for bitname in reversed(typ.bits) + ) + + ",\n" + ) + ret += "\t};\n" + ret += f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" + ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += f"\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' + ret += "\t\t val & ~mask);\n" + ret += "\treturn false;\n" + case Struct(): + if len(typ.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + + struct_versions = typ.members[0].ver + + prefix = prefix0 + prev_size: int | None = None + for member in typ.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.cnt is not None: + assert prev_size + maxelem = 0 + if ( + typ.name in ["Twalk", "Rwalk"] and member.name[:1] == "w" + ): # SPECIAL + maxelem = 16 + ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), {maxelem}, validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" + else: + ret += f"validate_{member.typ.name}(ctx)" + if member.ver != struct_versions: + ret += " )" + prefix = prefix1 + prev_size = member.static_size + ret += ";\n" + ret += "}\n" + + # unmarshal_* ############################################################## + ret += """ +/* unmarshal_* ****************************************************************/ + +static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { + *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 1; +} + +static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { + *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 2; +} + +static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { + *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 4; +} + +static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { + *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 8; +} +""" + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" + match typ: + case Bitfield(): + ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" + case Struct(): + ret += "\tmemset(out, 0, sizeof(*out));\n" + + if typ.members: + struct_versions = typ.members[0].ver + for member in typ.members: + ret += "\t" + prefix = "\t" + if member.ver != struct_versions: + ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " + prefix = "\t\t" + if member.cnt: + if member.ver != struct_versions: + ret += "{\n" + ret += f"{prefix}out->{member.name} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" + ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + if member.ver != struct_versions: + ret += "\t}\n" + else: + ret += ( + f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + ) + ret += "}\n" + + # marshal_* ################################################################ + ret += """ +/* marshal_* ******************************************************************/ + +static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { + lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", + (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", + ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), + ctx->ctx->max_msg_size); + return true; +} + +static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { + if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) + return _marshal_too_large(ctx); + ctx->net_bytes[ctx->net_offset] = *val; + ctx->net_offset += 1; + return false; +} + +static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { + if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) + return _marshal_too_large(ctx); + encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 2; + return false; +} + +static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { + if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) + return true; + encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 4; + return false; +} + +static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { + if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) + return true; + encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 8; + return false; +} +""" + for typ in typs: + inline = ( + " FLATTEN" + if (isinstance(typ, Struct) and typ.msgid is not None) + else " ALWAYS_INLINE" + ) + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" + match typ: + case Bitfield(): + ret += "\n" + ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" + case Struct(): + if len(typ.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + prefix2 = "\t " + + struct_versions = typ.members[0].ver + prefix = prefix0 + for member in typ.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( " + c_vercond(idprefix, member.ver) + " && " + if member.cnt: + ret += "({" + ret += f"\n{prefix2}\tbool err = false;" + ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" + ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" + ret += f"\n{prefix2}\terr;" + ret += f"\n{prefix2}}})" + else: + ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + if member.ver != struct_versions: + ret += " )" + prefix = prefix1 + ret += ";\n" + ret += "}\n" + + # vtables ################################################################## + ret += f""" +/* vtables ********************************************************************/ + +#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ + .basesize = sizeof(struct {idprefix}msg_##typ), \\ + .validate = validate_##typ, \\ + .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ + .marshal = (_marshal_fn_t)marshal_##typ, \\ + }} + +struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{ +""" + + ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n" + for msg in just_structs_msg(typs): + if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL + ret += f"\t\t_MSG({msg.name}),\n" + ret += "\t}},\n" + + for ver in sorted(versions): + ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" + for msg in just_structs_msg(typs): + if ver not in msg.msgver: + continue + ret += f"\t\t_MSG({msg.name}),\n" + ret += "\t}},\n" + ret += "};\n" + + ############################################################################ + return ret + + +################################################################################ + + +class Parser: + cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {} + + def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]: + filename = os.path.normpath(filename) + if filename not in self.cache: + + def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]: + return self.parse_file(os.path.join(filename, "..", other_filename)) + + self.cache[filename] = parse_file(filename, get_include) + return self.cache[filename] + + def all(self) -> tuple[set[str], list[Bitfield | Struct]]: + ret_versions: set[str] = set() + ret_typs: dict[str, Bitfield | Struct] = {} + for version, typs in self.cache.values(): + if version in ret_versions: + raise ValueError(f"duplicate protocol version {repr(version)}") + ret_versions.add(version) + for typ in typs: + if typ.name in ret_typs: + if typ != ret_typs[typ.name]: + raise ValueError(f"duplicate type name {repr(typ.name)}") + else: + ret_typs[typ.name] = typ + return ret_versions, list(ret_typs.values()) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + raise ValueError("requires at least 1 .txt filename") + parser = Parser() + for txtname in sys.argv[1:]: + parser.parse_file(txtname) + versions, typs = parser.all() + outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) + with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: + fh.write(gen_h("lib9p_", versions, typs)) + with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: + fh.write(gen_c("lib9p_", versions, typs)) |