summaryrefslogtreecommitdiff
path: root/lib9p/9p.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/9p.gen')
-rwxr-xr-xlib9p/9p.gen911
1 files changed, 911 insertions, 0 deletions
diff --git a/lib9p/9p.gen b/lib9p/9p.gen
new file mode 100755
index 0000000..f974dd1
--- /dev/null
+++ b/lib9p/9p.gen
@@ -0,0 +1,911 @@
+#!/usr/bin/env python
+# lib9p/9p.gen - Generate C marshalers/unmarshalers for .txt files
+# defining 9P protocol variants.
+#
+# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-Licence-Identifier: AGPL-3.0-or-later
+
+import enum
+import os.path
+import re
+from typing import Callable, Sequence
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.txt inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+# Parse *.txt ##################################################################
+
+
+class Atom(enum.Enum):
+ u8 = 1
+ u16 = 2
+ u32 = 4
+ u64 = 8
+
+ @property
+ def name(self) -> str:
+ return str(self.value)
+
+ @property
+ def static_size(self) -> int:
+ return self.value
+
+
+class BitfieldVal:
+ name: str
+ val: str
+ ver: set[str]
+
+ def __init__(self) -> None:
+ self.ver = set()
+
+
+class Bitfield:
+ name: str
+ bits: list[str]
+ names: dict[str, BitfieldVal]
+
+ @property
+ def static_size(self) -> int:
+ return int((len(self.bits) + 7) / 8)
+
+ def bitname_is_valid(self, bitname: str, ver: str | None = None) -> bool:
+ assert bitname in self.bits
+ if not bitname:
+ return False
+ if bitname.startswith("_"):
+ return False
+ if ver and (ver not in self.names[bitname].ver):
+ return False
+ return True
+
+
+# `msgid/structname = "member1 member2..."`
+# `structname = "member1 member2..."`
+# `structname += "member1 member2..."`
+class Struct:
+ msgid: int | None = None
+ msgver: set[str]
+ name: str
+ members: list["Member"]
+
+ def __init__(self) -> None:
+ self.msgver = set()
+
+ @property
+ def static_size(self) -> int | None:
+ size = 0
+ for member in self.members:
+ msize = member.static_size
+ if msize is None:
+ return None
+ size += msize
+ return size
+
+
+# `cnt*(name[typ])`
+# the `cnt*(...)` wrapper is optional
+class Member:
+ cnt: str | None = None
+ name: str
+ typ: Atom | Bitfield | Struct
+ ver: set[str]
+
+ @property
+ def static_size(self) -> int | None:
+ if self.cnt:
+ return None
+ return self.typ.static_size
+
+
+re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)"
+re_memberspec = (
+ f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>.*)\\]\\)?"
+)
+
+
+def parse_members(
+ ver: str,
+ env: dict[str, Atom | Bitfield | Struct],
+ existing: list[Member],
+ specs: str,
+) -> list[Member]:
+ ret = existing
+ for spec in specs.split():
+ m = re.fullmatch(re_memberspec, spec)
+ if not m:
+ raise SyntaxError(f"invalid member spec {repr(spec)}")
+
+ member = Member()
+ member.ver = {ver}
+
+ member.name = m.group("name")
+ if any(x.name == member.name for x in ret):
+ raise ValueError(f"duplicate member name {repr(member.name)}")
+
+ if m.group("typ") not in env:
+ raise NameError(f"Unknown type {repr(m.group(2))}")
+ member.typ = env[m.group("typ")]
+
+ if cnt := m.group("cnt"):
+ if len(ret) == 0 or ret[-1].name != cnt:
+ raise ValueError(f"list count must be previous item: {repr(cnt)}")
+ if not isinstance(ret[-1].typ, Atom):
+ raise ValueError(f"list count must be an integer type: {repr(cnt)}")
+ member.cnt = cnt
+
+ ret += [member]
+ return ret
+
+
+re_version = r'version\s+"(?P<version>[^"]+)"'
+re_import = r"from\s+(?P<file>\S+)\s+import\s+(?P<syms>\S+(?:\s*,\s*\S+)*)\s*"
+re_structspec = (
+ r'(?:(?P<msgid>[0-9]+)/)?(?P<name>\S+)\s*(?P<op>\+?=)\s*"(?P<members>[^"]*)"'
+)
+re_structspec_cont = r'\s+"(?P<members>[^"]*)"'
+re_bitfieldspec = r"bitfield\s+(?P<name>\S+)\s+(?P<size>[0-9]+)"
+re_bitfieldspec_bit = r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<bit>[0-9]+)/(?P<name>\S+)"
+re_bitfieldspec_alias = (
+ r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<name>\S+)\s*=\s*(?P<val>.*)"
+)
+
+
+def parse_file(
+ filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]]
+) -> tuple[str, list[Bitfield | Struct]]:
+ version: str | None = None
+ env: dict[str, Atom | Bitfield | Struct] = {
+ "1": Atom.u8,
+ "2": Atom.u16,
+ "4": Atom.u32,
+ "8": Atom.u64,
+ }
+ with open(filename, "r") as fh:
+ prev: Struct | Bitfield | None = None
+ for line in fh:
+ line = line.split("#", 1)[0].rstrip()
+ if not line:
+ continue
+ if m := re.fullmatch(re_version, line):
+ if version:
+ raise SyntaxError("must have exactly 1 version line")
+ version = m.group("version")
+ elif m := re.fullmatch(re_import, line):
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+ other_version, other_typs = get_include(m.group("file"))
+ for symname in m.group("syms").split(sep=","):
+ symname = symname.strip()
+ for typ in other_typs:
+ if typ.name == symname or symname == "*":
+ match typ:
+ case Bitfield():
+ for val in typ.names.values():
+ if other_version in val.ver:
+ val.ver.add(version)
+ case Struct():
+ if typ.msgid:
+ typ.msgver.add(version)
+ for member in typ.members:
+ if other_version in member.ver:
+ member.ver.add(version)
+ env[typ.name] = typ
+ elif m := re.fullmatch(re_structspec, line):
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+ if m.group("op") == "+=" and m.group("msgid"):
+ raise SyntaxError("cannot += to a message that is not yet defined")
+ match m.group("op"):
+ case "=":
+ struct = Struct()
+ if m.group("msgid"):
+ struct.msgid = int(m.group("msgid"))
+ struct.msgver.add(version)
+ struct.name = m.group("name")
+ struct.members = parse_members(
+ version, env, [], m.group("members")
+ )
+ env[struct.name] = struct
+ prev = struct
+ case "+=":
+ if m.group("name") not in env:
+ raise NameError(f"Unknown type {repr(m.group('name'))}")
+ _struct = env[m.group("name")]
+ if not isinstance(_struct, Struct):
+ raise NameError(
+ f"Type {repr(_struct.name)} is not a struct"
+ )
+ struct = _struct
+ struct.members = parse_members(
+ version, env, struct.members, m.group("members")
+ )
+ prev = struct
+ elif m := re.fullmatch(re_structspec_cont, line):
+ if not isinstance(prev, Struct):
+ raise SyntaxError(
+ "struct-continuation line must come after a struct line"
+ )
+ assert version
+ prev.members = parse_members(
+ version, env, prev.members, m.group("members")
+ )
+ elif m := re.fullmatch(re_bitfieldspec, line):
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+ bf = Bitfield()
+ bf.name = m.group("name")
+ bf.bits = int(m.group("size")) * [""]
+ bf.names = {}
+ if len(bf.bits) not in [8, 16, 32, 64]:
+ raise ValueError(f"Bitfield {repr(bf.name)} has an unusual size")
+ env[bf.name] = bf
+ prev = bf
+ elif m := re.fullmatch(re_bitfieldspec_bit, line):
+ if m.group("bitfield"):
+ if m.group("bitfield") not in env:
+ raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}")
+ _bf = env[m.group("bitfield")]
+ if not isinstance(_bf, Bitfield):
+ raise NameError(f"Type {repr(_bf.name)} is not a bitfield")
+ bf = _bf
+ prev = bf
+ else:
+ if not isinstance(prev, Bitfield):
+ raise SyntaxError(
+ "bitfield-continuation line must come after a bitfield line"
+ )
+ bf = prev
+ bit = int(m.group("bit"))
+ name = m.group("name")
+ if bit < 0 or bit >= len(bf.bits):
+ raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
+ if bf.bits[bit]:
+ raise ValueError(f"{bf.name}: bit {bit} already assigned")
+ if name in bf.names:
+ raise ValueError(f"{bf.name}: name {name} already assigned")
+
+ bf.bits[bit] = name
+
+ assert version
+ val = BitfieldVal()
+ val.name = name
+ val.val = f"1<<{bit}"
+ val.ver.add(version)
+ bf.names[name] = val
+ elif m := re.fullmatch(re_bitfieldspec_alias, line):
+ if m.group("bitfield"):
+ if m.group("bitfield") not in env:
+ raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}")
+ _bf = env[m.group("bitfield")]
+ if not isinstance(_bf, Bitfield):
+ raise NameError(f"Type {repr(_bf.name)} is not a bitfield")
+ bf = _bf
+ prev = bf
+ else:
+ if not isinstance(prev, Bitfield):
+ raise SyntaxError(
+ "bitfield-continuation line must come after a bitfield line"
+ )
+ bf = prev
+ name = m.group("name")
+ valstr = m.group("val")
+ if name in bf.names:
+ raise ValueError(f"{bf.name}: name {name} already assigned")
+
+ assert version
+ val = BitfieldVal()
+ val.name = name
+ val.val = valstr
+ val.ver.add(version)
+ bf.names[name] = val
+ else:
+ raise SyntaxError(f"invalid line {repr(line)}")
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+
+ typs = [x for x in env.values() if not isinstance(x, Atom)]
+ return version, typs
+
+
+# Generate C ###################################################################
+
+
+def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str:
+ match typ:
+ case Atom():
+ return f"uint{typ.value*8}_t"
+ case Bitfield():
+ return f"{idprefix}{typ.name}_t"
+ case Struct():
+ if typ.msgid is not None:
+ return f"struct {idprefix}msg_{typ.name}"
+ return f"struct {idprefix}{typ.name}"
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def c_verenum(idprefix: str, ver: str) -> str:
+ return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
+
+
+def c_vercomment(versions: set[str]) -> str | None:
+ if "9P2000" in versions:
+ return None
+ return "/* " + (", ".join(sorted(versions))) + " */"
+
+
+def c_vercond(idprefix: str, versions: set[str]) -> str:
+ if len(versions) == 1:
+ return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})"
+ return (
+ "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )"
+ )
+
+
+def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
+ return list(typ for typ in typs if isinstance(typ, Struct))
+
+
+def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
+ return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None)
+
+
+def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
+ return list(
+ typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None
+ )
+
+
+def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]:
+ return list(typ for typ in typs if isinstance(typ, Bitfield))
+
+
+def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str:
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#ifndef _LIB9P_9P_H_
+# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
+#endif
+
+#include <stdint.h> /* for uint{{n}}_t types */
+"""
+
+ ret += f"""
+/* versions *******************************************************************/
+
+enum {idprefix}version {{
+"""
+ fullversions = ["unknown = 0", *sorted(versions)]
+ verwidth = max(len(v) for v in fullversions)
+ for ver in fullversions:
+ ret += f"\t{c_verenum(idprefix, ver)},"
+ ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
+ ret += f"\t{c_verenum(idprefix, 'NUM')},\n"
+ ret += "};\n"
+ ret += "\n"
+ ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
+
+ ret += """
+/* non-message types **********************************************************/
+"""
+ for bf in just_bitfields(typs):
+ ret += "\n"
+ ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n"
+ names = [
+ *reversed([bf.bits[n] or f"_UNUSED_{n}" for n in range(0, len(bf.bits))]),
+ *[k for k in bf.names if k not in bf.bits],
+ ]
+ namewidth = max(len(name) for name in names)
+
+ for name in names:
+ if name.startswith("_"):
+ cname = f"_{idprefix.upper()}{bf.name.upper()}_{name[1:]}"
+ else:
+ cname = f"{idprefix.upper()}{bf.name.upper()}_{name}"
+ if name in bf.names:
+ val = bf.names[name].val
+ else:
+ assert name.startswith("_UNUSED_")
+ val = f"1<<{name[len('_UNUSED_'):]}"
+ ret += f"#define {cname}{' '*(namewidth-len(name))} (({c_typename(idprefix, bf)})({val}))"
+ if (name in bf.names) and (comment := c_vercomment(bf.names[name].ver)):
+ ret += " " + comment
+ ret += "\n"
+
+ for struct in just_structs_nonmsg(typs):
+ all_the_same = len(struct.members) == 0 or all(
+ m.ver == struct.members[0].ver for m in struct.members
+ )
+ typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members)
+ if not all_the_same:
+ namewidth = max(len(m.name) for m in struct.members)
+
+ ret += "\n"
+ ret += c_typename(idprefix, struct) + " {\n"
+ for member in struct.members:
+ ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};"
+ if (not all_the_same) and (comment := c_vercomment(member.ver)):
+ ret += (" " * (namewidth - len(member.name))) + " " + comment
+ ret += "\n"
+ ret += "};\n"
+
+ ret += """
+/* messages *******************************************************************/
+
+"""
+ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
+ namewidth = max(len(msg.name) for msg in just_structs_msg(typs))
+ for msg in just_structs_msg(typs):
+ ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},"
+ if comment := c_vercomment(msg.msgver):
+ ret += " " + comment
+ ret += "\n"
+ ret += "};\n"
+ ret += "\n"
+ ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n"
+
+ for msg in just_structs_msg(typs):
+ ret += "\n"
+ if comment := c_vercomment(msg.msgver):
+ ret += comment + "\n"
+ ret += c_typename(idprefix, msg) + " {"
+ if not msg.members:
+ ret += "};\n"
+ continue
+ ret += "\n"
+
+ all_the_same = len(msg.members) == 0 or all(
+ m.ver == msg.members[0].ver for m in msg.members
+ )
+ typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members)
+ if not all_the_same:
+ namewidth = max(len(m.name) for m in msg.members)
+
+ for member in msg.members:
+ ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};"
+ if (not all_the_same) and (comment := c_vercomment(member.ver)):
+ ret += (" " * (namewidth - len(member.name))) + " " + comment
+ ret += "\n"
+ ret += "};\n"
+
+ return ret
+
+
+def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str:
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h> /* for size_t */
+#include <inttypes.h> /* for PRI* macros */
+#include <string.h> /* for memset() */
+
+#include <lib9p/9p.h>
+
+#include "internal.h"
+"""
+
+ def used(arg: str) -> str:
+ return arg
+
+ def unused(arg: str) -> str:
+ return f"UNUSED({arg})"
+
+ # strings ##################################################################
+ ret += f"""
+/* strings ********************************************************************/
+
+static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{
+"""
+ for ver in ["unknown", *sorted(versions)]:
+ ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n'
+ ret += "};\n"
+ ret += f"""
+const char *{idprefix}version_str(enum {idprefix}version ver) {{
+ assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')});
+ return version_strs[ver];
+}}
+
+static const char *msg_type_strs[0x100] = {{
+"""
+ id2name: dict[int, str] = {}
+ for msg in just_structs_msg(typs):
+ assert msg.msgid
+ id2name[msg.msgid] = msg.name
+ for n in range(0, 0x100):
+ ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n)))
+ ret += "};\n"
+ ret += f"""
+const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{
+ assert(0 <= typ && typ <= 0xFF);
+ return msg_type_strs[typ];
+}}
+"""
+
+ # validate_* ###############################################################
+ ret += """
+/* validate_* *****************************************************************/
+
+static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
+ if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
+ /* If needed-net-size overflowed uint32_t, then
+ * there's no way that actual-net-size will live up to
+ * that. */
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ if (ctx->net_offset > ctx->net_size)
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ return false;
+}
+
+static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
+ if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
+ /* If needed-host-size overflowed size_t, then there's
+ * no way that actual-net-size will live up to
+ * that. */
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ return false;
+}
+
+static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
+ size_t cnt, size_t max,
+ _validate_fn_t item_fn, size_t item_host_size) {
+ if (max && cnt > max)
+ return lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "list size is too large (%zu > %zu)",
+ cnt, max);
+ for (size_t i = 0; i < cnt; i++)
+ if (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
+ return true;
+ return false;
+}
+
+#define validate_1(ctx) _validate_size_net(ctx, 1)
+#define validate_2(ctx) _validate_size_net(ctx, 2)
+#define validate_4(ctx) _validate_size_net(ctx, 4)
+#define validate_8(ctx) _validate_size_net(ctx, 8)
+"""
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{"
+
+ if typ.name == "d": # SPECIAL
+ # Optimize... maybe the compiler could figure out to do
+ # this, but let's make it obvious.
+ ret += "\n"
+ ret += "\tuint32_t base_offset = ctx->net_offset;\n"
+ ret += "\tif (validate_4(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
+ ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
+ ret += "}\n"
+ continue
+ if typ.name == "s": # SPECIAL
+ # Add an extra nul-byte on the host, and validate UTF-8
+ # (also, similar optimization to "d").
+ ret += "\n"
+ ret += "\tuint32_t base_offset = ctx->net_offset;\n"
+ ret += "\tif (validate_2(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
+ ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
+ ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ match typ:
+ case Bitfield():
+ ret += "\n"
+ all_the_same = all(
+ val.ver == [*typ.names.values()][0].ver
+ for val in typ.names.values()
+ )
+ if (
+ all_the_same
+ and (len(typ.bits) == typ.static_size * 8)
+ and all(typ.bitname_is_valid(bitname) for bitname in typ.bits)
+ ):
+ ret += f"\treturn validate_{typ.static_size}(ctx));\n"
+ else:
+ ret += f"\t if (validate_{typ.static_size}(ctx))\n"
+ ret += "\t\treturn true;\n"
+ if all_the_same:
+ ret += (
+ f"\tstatic const {c_typename(idprefix, typ)} mask = 0b"
+ + "".join(
+ "1" if typ.bitname_is_valid(bitname) else "0"
+ for bitname in reversed(typ.bits)
+ )
+ + ";\n"
+ )
+ else:
+ ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_verenum(idprefix, 'NUM')}] = {{\n"
+ verwidth = max(len(ver) for ver in versions)
+ for ver in sorted(versions):
+ ret += (
+ f"\t\t[{c_verenum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b"
+ + "".join(
+ "1" if typ.bitname_is_valid(bitname, ver) else "0"
+ for bitname in reversed(typ.bits)
+ )
+ + ",\n"
+ )
+ ret += "\t};\n"
+ ret += f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n"
+ ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += f"\tif (val & ~mask)\n"
+ ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n'
+ ret += "\t\t val & ~mask);\n"
+ ret += "\treturn false;\n"
+ case Struct():
+ if len(typ.members) == 0:
+ ret += "\n\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ prefix0 = "\treturn "
+ prefix1 = "\t || "
+
+ struct_versions = typ.members[0].ver
+
+ prefix = prefix0
+ prev_size: int | None = None
+ for member in typ.members:
+ ret += f"\n{prefix}"
+ if member.ver != struct_versions:
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
+ if member.cnt is not None:
+ assert prev_size
+ maxelem = 0
+ if (
+ typ.name in ["Twalk", "Rwalk"] and member.name[:1] == "w"
+ ): # SPECIAL
+ maxelem = 16
+ ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), {maxelem}, validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
+ else:
+ ret += f"validate_{member.typ.name}(ctx)"
+ if member.ver != struct_versions:
+ ret += " )"
+ prefix = prefix1
+ prev_size = member.static_size
+ ret += ";\n"
+ ret += "}\n"
+
+ # unmarshal_* ##############################################################
+ ret += """
+/* unmarshal_* ****************************************************************/
+
+static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
+ *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 1;
+}
+
+static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
+ *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 2;
+}
+
+static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
+ *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 4;
+}
+
+static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
+ *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 8;
+}
+"""
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n"
+ match typ:
+ case Bitfield():
+ ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n"
+ case Struct():
+ ret += "\tmemset(out, 0, sizeof(*out));\n"
+
+ if typ.members:
+ struct_versions = typ.members[0].ver
+ for member in typ.members:
+ ret += "\t"
+ prefix = "\t"
+ if member.ver != struct_versions:
+ ret += "if ( " + c_vercond(idprefix, member.ver) + " ) "
+ prefix = "\t\t"
+ if member.cnt:
+ if member.ver != struct_versions:
+ ret += "{\n"
+ ret += f"{prefix}out->{member.name} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
+ ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
+ if member.ver != struct_versions:
+ ret += "\t}\n"
+ else:
+ ret += (
+ f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
+ )
+ ret += "}\n"
+
+ # marshal_* ################################################################
+ ret += """
+/* marshal_* ******************************************************************/
+
+static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
+ lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
+ (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
+ ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
+ ctx->ctx->max_msg_size);
+ return true;
+}
+
+static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
+ if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
+ return _marshal_too_large(ctx);
+ ctx->net_bytes[ctx->net_offset] = *val;
+ ctx->net_offset += 1;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
+ if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
+ return _marshal_too_large(ctx);
+ encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 2;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
+ if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
+ return true;
+ encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 4;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
+ if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
+ return true;
+ encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 8;
+ return false;
+}
+"""
+ for typ in typs:
+ inline = (
+ " FLATTEN"
+ if (isinstance(typ, Struct) and typ.msgid is not None)
+ else " ALWAYS_INLINE"
+ )
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{"
+ match typ:
+ case Bitfield():
+ ret += "\n"
+ ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n"
+ case Struct():
+ if len(typ.members) == 0:
+ ret += "\n\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ prefix0 = "\treturn "
+ prefix1 = "\t || "
+ prefix2 = "\t "
+
+ struct_versions = typ.members[0].ver
+ prefix = prefix0
+ for member in typ.members:
+ ret += f"\n{prefix}"
+ if member.ver != struct_versions:
+ ret += "( " + c_vercond(idprefix, member.ver) + " && "
+ if member.cnt:
+ ret += "({"
+ ret += f"\n{prefix2}\tbool err = false;"
+ ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
+ ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
+ ret += f"\n{prefix2}\terr;"
+ ret += f"\n{prefix2}}})"
+ else:
+ ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
+ if member.ver != struct_versions:
+ ret += " )"
+ prefix = prefix1
+ ret += ";\n"
+ ret += "}\n"
+
+ # vtables ##################################################################
+ ret += f"""
+/* vtables ********************************************************************/
+
+#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
+ .basesize = sizeof(struct {idprefix}msg_##typ), \\
+ .validate = validate_##typ, \\
+ .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
+ .marshal = (_marshal_fn_t)marshal_##typ, \\
+ }}
+
+struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{
+"""
+
+ ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n"
+ for msg in just_structs_msg(typs):
+ if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL
+ ret += f"\t\t_MSG({msg.name}),\n"
+ ret += "\t}},\n"
+
+ for ver in sorted(versions):
+ ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n"
+ for msg in just_structs_msg(typs):
+ if ver not in msg.msgver:
+ continue
+ ret += f"\t\t_MSG({msg.name}),\n"
+ ret += "\t}},\n"
+ ret += "};\n"
+
+ ############################################################################
+ return ret
+
+
+################################################################################
+
+
+class Parser:
+ cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {}
+
+ def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]:
+ filename = os.path.normpath(filename)
+ if filename not in self.cache:
+
+ def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]:
+ return self.parse_file(os.path.join(filename, "..", other_filename))
+
+ self.cache[filename] = parse_file(filename, get_include)
+ return self.cache[filename]
+
+ def all(self) -> tuple[set[str], list[Bitfield | Struct]]:
+ ret_versions: set[str] = set()
+ ret_typs: dict[str, Bitfield | Struct] = {}
+ for version, typs in self.cache.values():
+ if version in ret_versions:
+ raise ValueError(f"duplicate protocol version {repr(version)}")
+ ret_versions.add(version)
+ for typ in typs:
+ if typ.name in ret_typs:
+ if typ != ret_typs[typ.name]:
+ raise ValueError(f"duplicate type name {repr(typ.name)}")
+ else:
+ ret_typs[typ.name] = typ
+ return ret_versions, list(ret_typs.values())
+
+
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 2:
+ raise ValueError("requires at least 1 .txt filename")
+ parser = Parser()
+ for txtname in sys.argv[1:]:
+ parser.parse_file(txtname)
+ versions, typs = parser.all()
+ outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
+ with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
+ fh.write(gen_h("lib9p_", versions, typs))
+ with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
+ fh.write(gen_c("lib9p_", versions, typs))