summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-09 00:00:32 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-09 00:00:32 -0600
commitbed78039f9bf086d35a3ae5efc8e6701c50ed006 (patch)
tree38fd0e196f08673f43f984b4228a4b6b5c647431 /lib9p/idl.gen
parent5dfff05adef9b1e09cff350c2dc551fdac6d227a (diff)
wip idl rewrite
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen1180
1 files changed, 1180 insertions, 0 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
new file mode 100755
index 0000000..1dafef9
--- /dev/null
+++ b/lib9p/idl.gen
@@ -0,0 +1,1180 @@
+#!/usr/bin/env python
+# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
+# defining 9P protocol variants.
+#
+# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-Licence-Identifier: AGPL-3.0-or-later
+
+import enum
+import os.path
+import re
+from abc import ABC, abstractmethod
+from typing import Callable, Literal, TypeAlias, TypeVar
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+T = TypeVar("T")
+
+# Types ########################################################################
+
+Type: TypeAlias = "Primitive | Number | Bitfield | Struct | Message"
+
+
+class Primitive(enum.Enum):
+ u8 = 1
+ u16 = 2
+ u32 = 4
+ u64 = 8
+
+ @property
+ def in_versions(self) -> set[str]:
+ return set()
+
+ @property
+ def name(self) -> str:
+ return str(self.value)
+
+ @property
+ def static_size(self) -> int:
+ return self.value
+
+
+class Number:
+ name: str
+ in_versions: set[str]
+
+ prim: Primitive
+
+ def __init__(self) -> None:
+ self.in_versions = set()
+
+ @property
+ def static_size(self) -> int:
+ return self.prim.static_size
+
+
+class BitfieldVal:
+ name: str
+ in_versions: set[str]
+
+ val: str
+
+ def __init__(self) -> None:
+ self.in_versions = set()
+
+
+class Bitfield:
+ name: str
+ in_versions: set[str]
+
+ prim: Primitive
+
+ bits: list[str] # bitnames
+ names: dict[str, BitfieldVal] # bits *and* aliases
+
+ def __init__(self) -> None:
+ self.in_versions = set()
+ self.names = {}
+
+ @property
+ def static_size(self) -> int:
+ return self.prim.static_size
+
+ def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
+ """Return whether the given bit is valid in the given protocol
+ version.
+
+ """
+ bitname = self.bits[bit] if isinstance(bit, int) else bit
+ assert bitname in self.bits
+ if not bitname:
+ return False
+ if bitname.startswith("_"):
+ return False
+ if ver and (ver not in self.names[bitname].in_versions):
+ return False
+ return True
+
+
+class ExprLit:
+ val: int
+
+ def __init__(self, val: int) -> None:
+ self.val = val
+
+
+class ExprSym:
+ name: str
+
+ def __init__(self, name: str) -> None:
+ self.name = name
+
+
+class ExprOp:
+ op: Literal["-", "+"]
+
+ def __init__(self, op: Literal["-", "+"]) -> None:
+ self.op = op
+
+
+class Expr:
+ tokens: list[ExprLit | ExprSym | ExprOp]
+
+ def __init__(self) -> None:
+ self.tokens = []
+
+ def __bool__(self) -> bool:
+ return len(self.tokens) > 0
+
+
+class StructMember:
+ # from left-to-right when parsing
+ cnt: str | None = None
+ name: str
+ typ: Type
+ max: Expr
+ val: Expr
+
+ in_versions: set[str]
+
+ @property
+ def static_size(self) -> int | None:
+ if self.cnt:
+ return None
+ return self.typ.static_size
+
+
+class Struct:
+ name: str
+ in_versions: set[str]
+
+ members: list[StructMember]
+
+ def __init__(self) -> None:
+ self.in_versions = set()
+
+ @property
+ def static_size(self) -> int | None:
+ size = 0
+ for member in self.members:
+ msize = member.static_size
+ if msize is None:
+ return None
+ size += msize
+ return size
+
+
+class Message(Struct):
+ @property
+ def msgid(self) -> int:
+ assert len(self.members) >= 3
+ assert self.members[1].name == "typ"
+ assert self.members[1].static_size == 1
+ assert self.members[1].val
+ assert len(self.members[1].val.tokens) == 1
+ assert isinstance(self.members[1].val.tokens[0], ExprLit)
+ return self.members[1].val.tokens[0].val
+
+
+# Parse *.9p ###################################################################
+
+re_priname = "(?:1|2|4|8)" # primitive names
+re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names
+re_impname = r"(?:\*|" + re_symname + ")" # names we can import
+re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be
+
+re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be
+
+re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)"
+
+re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})"
+re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"
+
+re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"
+
+
+def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
+ spec = spec.strip()
+
+ bit: int | None
+ val: BitfieldVal
+ if m := re.fullmatch(re_bitspec_bit, spec):
+ bit = int(m.group("bit"))
+ name = m.group("name")
+
+ val = BitfieldVal()
+ val.name = name
+ val.val = f"1<<{bit}"
+ val.in_versions.add(ver)
+
+ if bit < 0 or bit >= len(bf.bits):
+ raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
+ if bf.bits[bit]:
+ raise ValueError(f"{bf.name}: bit {bit} already assigned")
+ bf.bits[bit] = val.name
+ elif m := re.fullmatch(re_bitspec_alias, spec):
+ name = m.group("name")
+ valstr = m.group("val")
+
+ val = BitfieldVal()
+ val.name = name
+ val.val = valstr
+ val.in_versions.add(ver)
+ else:
+ raise SyntaxError(f"invalid bitfield spec {repr(spec)}")
+
+ if val.name in bf.names:
+ raise ValueError(f"{bf.name}: name {val.name} already assigned")
+ bf.names[val.name] = val
+
+
+def parse_expr(expr: str) -> Expr:
+ assert re.fullmatch(re_expr, expr)
+ ret = Expr()
+ for tok in re.split("([-+])", expr):
+ if tok == "-" or tok == "+":
+ ret.tokens += [ExprOp(tok)]
+ elif re.fullmatch("[0-9]+", tok):
+ ret.tokens += [ExprLit(int(tok))]
+ else:
+ ret.tokens += [ExprSym(tok)]
+ return ret
+
+
+def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None:
+ for spec in specs.split():
+ m = re.fullmatch(re_memberspec, spec)
+ if not m:
+ raise SyntaxError(f"invalid member spec {repr(spec)}")
+
+ member = StructMember()
+ member.in_versions = {ver}
+
+ member.name = m.group("name")
+ if any(x.name == member.name for x in struct.members):
+ raise ValueError(f"duplicate member name {repr(member.name)}")
+
+ if m.group("typ") not in env:
+ raise NameError(f"Unknown type {repr(m.group(2))}")
+ member.typ = env[m.group("typ")]
+
+ if cnt := m.group("cnt"):
+ if len(struct.members) == 0 or struct.members[-1].name != cnt:
+ raise ValueError(f"list count must be previous item: {repr(cnt)}")
+ if not isinstance(struct.members[-1].typ, Primitive):
+ raise ValueError(f"list count must be an integer type: {repr(cnt)}")
+ member.cnt = cnt
+
+ if maxstr := m.group("max"):
+ if (not isinstance(member.typ, Primitive)) or member.cnt:
+ raise ValueError("',max=' may only be specified on a non-repeated atom")
+ member.max = parse_expr(maxstr)
+ else:
+ member.max = Expr()
+
+ if valstr := m.group("val"):
+ if (not isinstance(member.typ, Primitive)) or member.cnt:
+ raise ValueError("',val=' may only be specified on a non-repeated atom")
+ member.val = parse_expr(valstr)
+ else:
+ member.val = Expr()
+
+ struct.members += [member]
+
+
+def re_string(grpname: str) -> str:
+ return f'"(?P<{grpname}>[^"]*)"'
+
+
+re_line_version = f"version\\s+{re_string('version')}"
+re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)"
+re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
+re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
+re_line_bitfield_ = f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}"
+re_line_struct = (
+ f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
+)
+re_line_msg = (
+ f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
+)
+re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg
+
+
+def parse_file(
+ filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
+) -> tuple[str, list[Type]]:
+ version: str | None = None
+ env: dict[str, Type] = {
+ "1": Primitive.u8,
+ "2": Primitive.u16,
+ "4": Primitive.u32,
+ "8": Primitive.u64,
+ }
+
+ def get_type(name: str, tc: type[T]) -> T:
+ nonlocal env
+ if name not in env:
+ raise NameError(f"Unknown type {repr(name)}")
+ ret = env[name]
+ if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
+ raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
+ return ret
+
+ with open(filename, "r") as fh:
+ prev: Type | None = None
+ for line in fh:
+ line = line.split("#", 1)[0].rstrip()
+ if not line:
+ continue
+ if m := re.fullmatch(re_line_version, line):
+ if version:
+ raise SyntaxError("must have exactly 1 version line")
+ version = m.group("version")
+ continue
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+
+ if m := re.fullmatch(re_line_import, line):
+ other_version, other_typs = get_include(m.group("file"))
+ for symname in m.group("syms").split(sep=","):
+ symname = symname.strip()
+ for typ in other_typs:
+ if typ.name == symname or symname == "*":
+ match typ:
+ case Primitive():
+ pass
+ case Number():
+ typ.in_versions.add(version)
+ case Bitfield():
+ typ.in_versions.add(version)
+ for val in typ.names.values():
+ if other_version in val.in_versions:
+ val.in_versions.add(version)
+ case Struct(): # and Message()
+ typ.in_versions.add(version)
+ for member in typ.members:
+ if other_version in member.in_versions:
+ member.in_versions.add(version)
+ env[typ.name] = typ
+ elif m := re.fullmatch(re_line_num, line):
+ num = Number()
+ num.name = m.group("name")
+ num.in_versions.add(version)
+
+ prim = env[m.group("prim")]
+ assert isinstance(prim, Primitive)
+ num.prim = prim
+
+ env[num.name] = num
+ prev = num
+ elif m := re.fullmatch(re_line_bitfield, line):
+ bf = Bitfield()
+ bf.name = m.group("name")
+
+ prim = env[m.group("prim")]
+ assert isinstance(prim, Primitive)
+ bf.prim = prim
+
+ bf.bits = (prim.static_size * 8) * [""]
+
+ env[bf.name] = bf
+ prev = bf
+ elif m := re.fullmatch(re_line_bitfield_, line):
+ bf = get_type(m.group("name"), Bitfield)
+ parse_bitspec(version, bf, m.group("member"))
+
+ prev = bf
+ elif m := re.fullmatch(re_line_struct, line):
+ match m.group("op"):
+ case "=":
+ struct = Struct()
+ struct.name = m.group("name")
+ struct.in_versions.add(version)
+ struct.members = []
+ parse_members(version, env, struct, m.group("members"))
+
+ env[struct.name] = struct
+ prev = struct
+ case "+=":
+ struct = get_type(m.group("name"), Struct)
+ parse_members(version, env, struct, m.group("members"))
+
+ prev = struct
+ elif m := re.fullmatch(re_line_msg, line):
+ match m.group("op"):
+ case "=":
+ msg = Message()
+ msg.name = m.group("name")
+ msg.in_versions.add(version)
+ msg.members = []
+ parse_members(version, env, msg, m.group("members"))
+
+ env[msg.name] = msg
+ prev = msg
+ case "+=":
+ msg = get_type(m.group("name"), Message)
+ parse_members(version, env, msg, m.group("members"))
+
+ prev = msg
+ elif m := re.fullmatch(re_line_cont, line):
+ match prev:
+ case Bitfield():
+ parse_bitspec(version, prev, m.group("specs"))
+ case Struct(): # and Message()
+ parse_members(version, env, prev, m.group("specs"))
+ case _:
+ raise SyntaxError(
+ "continuation line must come after a bitfield, struct, or msg line"
+ )
+ else:
+ raise SyntaxError(f"invalid line {repr(line)}")
+ if not version:
+ raise SyntaxError("must have exactly 1 version line")
+
+ typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]
+
+ for typ in [typ for typ in typs if isinstance(typ, Struct)]:
+ valid_syms = ["end", *["&" + m.name for m in typ.members]]
+ for member in typ.members:
+ for tok in [*member.max.tokens, *member.val.tokens]:
+ if isinstance(tok, ExprSym) and tok.name not in valid_syms:
+ raise ValueError(
+ f"{typ.name}.{member.name}: invalid sym: {tok.name}"
+ )
+
+ return version, typs
+
+
+# Generate C ###################################################################
+
+
+def c_ver_enum(idprefix: str, ver: str) -> str:
+ return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
+
+
+def c_ver_ifdef(idprefix: str, versions: set[str]) -> str:
+ return " || ".join(
+ f"defined(CONFIG_{idprefix.upper()}ENABLE_{c_ver_enum('', v)})"
+ for v in sorted(versions)
+ )
+
+
+def c_ver_cond(idprefix: str, versions: set[str]) -> str:
+ if len(versions) == 1:
+ return f"(ctx->ctx->version=={c_ver_enum(idprefix, next(v for v in versions))})"
+ return (
+ "( " + (" || ".join(c_ver_cond(idprefix, {v}) for v in sorted(versions))) + " )"
+ )
+
+
+def c_typename(idprefix: str, typ: Type) -> str:
+ match typ:
+ case Primitive():
+ return f"uint{typ.value*8}_t"
+ case Number():
+ return f"{idprefix}{typ.name}_t"
+ case Bitfield():
+ return f"{idprefix}{typ.name}_t"
+ case Message():
+ return f"struct {idprefix}msg_{typ.name}"
+ case Struct():
+ return f"struct {idprefix}{typ.name}"
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def gen_h(idprefix: str, versions: set[str], typs: list[Type]) -> str:
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#ifndef _LIB9P_9P_H_
+# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
+#endif
+
+#include <stdint.h> /* for uint{{n}}_t types */
+"""
+
+ _ifdef: list[str] = []
+
+ def push_ifdef(v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ ret += f"#if {v}\n"
+ _ifdef += [v]
+
+ def pop_ifdef(n: int) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ while len(_ifdef) > n:
+ ret += f"#endif /* {_ifdef[-1]}\n"
+ _ifdef = _ifdef[:-1]
+
+ def set_ifdef(v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ if v != _ifdef[-1]:
+ ret += f"#elif {v}\n"
+ _ifdef[-1] = v
+
+ def pushorset_ifdef(n: int, v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ if len(_ifdef) < n:
+ push_ifdef(v)
+ else:
+ set_ifdef(v)
+
+ ret += f"""
+/* versions *******************************************************************/
+
+enum {idprefix}version {{
+"""
+ fullversions = ["unknown = 0", *sorted(versions)]
+ verwidth = max(len(v) for v in fullversions)
+ for ver in fullversions:
+ if ver in versions:
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, {ver}))
+ ret += f"\t{c_ver_enum(idprefix, ver)},"
+ ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
+ pop_ifdef(0)
+ ret += f"\t{c_ver_enum(idprefix, 'NUM')},\n"
+ ret += "};\n"
+ ret += "\n"
+ ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
+
+ ret += """
+/* non-message types **********************************************************/
+"""
+ for typ in [typ for typ in typs if not isinstance(typ, Message)]:
+ ret += "\n"
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, typ.in_versions))
+ match typ:
+ case Number():
+ ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n"
+ case Bitfield():
+ ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n"
+ names = [
+ *reversed(
+ [typ.bits[n] or f"_UNUSED_{n}" for n in range(0, len(typ.bits))]
+ ),
+ "",
+ *[k for k in typ.names if k not in typ.bits],
+ ]
+ namewidth = max(len(name) for name in names)
+
+ ret += "\n"
+ for name in names:
+ if name == "":
+ ret += "\n"
+ continue
+ pushorset_ifdef(
+ 2, c_ver_ifdef(idprefix, typ.names[name].in_versions)
+ )
+ if name.startswith("_"):
+ c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}"
+ else:
+ c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}"
+ if name in typ.names:
+ val = typ.names[name].val
+ else:
+ assert name.startswith("_UNUSED_")
+ val = f"1<<{name[len('_UNUSED_'):]}"
+ ret += f"#define {c_name}{' '*(namewidth-len(name))} (({c_typename(idprefix, typ)})({val}))\n"
+ pop_ifdef(1)
+ case Struct():
+ typewidth = max(len(c_typename(idprefix, m.typ)) for m in typ.members)
+
+ ret += c_typename(idprefix, typ) + " {\n"
+ for member in typ.members:
+ if member.val:
+ continue
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ c_type = c_typename(idprefix, member.typ)
+ if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL
+ c_type = "char"
+ ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ pop_ifdef(1)
+ ret += "};\n"
+ pop_ifdef(0)
+
+ ret += """
+/* messages *******************************************************************/
+
+"""
+ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
+ namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message))
+ for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, msg.in_versions))
+ ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
+ pop_ifdef(0)
+ ret += "};\n"
+ ret += "\n"
+ ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n"
+
+ for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ ret += "\n"
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, msg.in_versions))
+ ret += c_typename(idprefix, msg) + " {"
+ if not msg.members:
+ ret += "};\n"
+ continue
+ ret += "\n"
+
+ typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members)
+
+ for member in msg.members:
+ if member.val:
+ continue
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ pop_ifdef(1)
+ ret += "};\n"
+ pop_ifdef(0)
+
+ return ret
+
+
+def c_expr(expr: Expr) -> str:
+ ret: list[str] = []
+ for tok in expr.tokens:
+ match tok:
+ case ExprOp():
+ ret += [tok.op]
+ case ExprLit():
+ ret += [str(tok.val)]
+ case ExprSym(name="end"):
+ ret += ["ctx->net_offset"]
+ case ExprSym():
+ ret += [f"_{tok.name[1:]}_offset"]
+ return " ".join(ret)
+
+
+def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str:
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#include <assert.h>
+#include <stdbool.h>
+#include <stddef.h> /* for size_t */
+#include <inttypes.h> /* for PRI* macros */
+#include <string.h> /* for memset() */
+
+#include <lib9p/9p.h>
+
+#include "internal.h"
+"""
+
+ _ifdef: list[str] = []
+
+ def push_ifdef(v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ ret += f"#if {v}\n"
+ _ifdef += [v]
+
+ def pop_ifdef(n: int) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ while len(_ifdef) > n:
+ ret += f"#endif /* {_ifdef[-1]}\n"
+ _ifdef = _ifdef[:-1]
+
+ def set_ifdef(v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ if v != _ifdef[-1]:
+ ret += f"#elif {v}\n"
+ _ifdef[-1] = v
+
+ def pushorset_ifdef(n: int, v: str) -> None:
+ nonlocal _ifdef
+ nonlocal ret
+ if len(_ifdef) < n:
+ push_ifdef(v)
+ else:
+ set_ifdef(v)
+
+ def used(arg: str) -> str:
+ return arg
+
+ def unused(arg: str) -> str:
+ return f"UNUSED({arg})"
+
+ # strings ##################################################################
+ ret += f"""
+/* strings ********************************************************************/
+
+static const char *version_strs[{c_ver_enum(idprefix, 'NUM')}] = {{
+"""
+ for ver in ["unknown", *sorted(versions)]:
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, {ver}))
+ ret += f'\t[{c_ver_enum(idprefix, ver)}] = "{ver}",\n'
+ pop_ifdef(0)
+ ret += "};\n"
+ ret += f"""
+const char *{idprefix}version_str(enum {idprefix}version ver) {{
+ assert(0 <= ver && ver < {c_ver_enum(idprefix, 'NUM')});
+ return version_strs[ver];
+}}
+
+static const char *msg_type_strs[0x100] = {{
+"""
+ id2name: dict[int, str] = {}
+ for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ id2name[msg.msgid] = msg.name
+ for n in range(0, 0x100):
+ ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n)))
+ ret += "};\n"
+ ret += f"""
+const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{
+ assert(0 <= typ && typ <= 0xFF);
+ return msg_type_strs[typ];
+}}
+"""
+
+ # validate_* ###############################################################
+ ret += """
+/* validate_* *****************************************************************/
+
+static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
+ if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
+ /* If needed-net-size overflowed uint32_t, then
+ * there's no way that actual-net-size will live up to
+ * that. */
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ if (ctx->net_offset > ctx->net_size)
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ return false;
+}
+
+static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
+ if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
+ /* If needed-host-size overflowed size_t, then there's
+ * no way that actual-net-size will live up to
+ * that. */
+ return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+ return false;
+}
+
+static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
+ size_t cnt,
+ _validate_fn_t item_fn, size_t item_host_size) {
+ for (size_t i = 0; i < cnt; i++)
+ if (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
+ return true;
+ return false;
+}
+
+#define validate_1(ctx) _validate_size_net(ctx, 1)
+#define validate_2(ctx) _validate_size_net(ctx, 2)
+#define validate_4(ctx) _validate_size_net(ctx, 4)
+#define validate_8(ctx) _validate_size_net(ctx, 8)
+"""
+ for typ in typs:
+ inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ pushorset_ifdef(1, c_ver_ifdef(idprefix, typ.in_versions))
+ ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
+
+ if typ.name == "d": # SPECIAL
+ # Optimize... maybe the compiler could figure out to do
+ # this, but let's make it obvious.
+ ret += "\tuint32_t base_offset = ctx->net_offset;\n"
+ ret += "\tif (validate_4(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
+ ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
+ ret += "}\n"
+ continue
+ if typ.name == "s": # SPECIAL
+ # Add an extra nul-byte on the host, and validate UTF-8
+ # (also, similar optimization to "d").
+ ret += "\tuint32_t base_offset = ctx->net_offset;\n"
+ ret += "\tif (validate_2(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
+ ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
+ ret += "\t\treturn true;\n"
+ ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
+ ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ match typ:
+ case Number():
+ ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
+ case Bitfield():
+ ret += f"\t if (validate_{typ.static_size}(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_ver_enum(idprefix, 'NUM')}] = {{\n"
+ verwidth = max(len(ver) for ver in versions)
+ for ver in sorted(versions):
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, {ver}))
+ ret += (
+ f"\t\t[{c_ver_enum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b"
+ + "".join(
+ "1" if typ.bit_is_valid(bitname, ver) else "0"
+ for bitname in reversed(typ.bits)
+ )
+ + ",\n"
+ )
+ pop_ifdef(1)
+ ret += "\t};\n"
+ ret += (
+ f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n"
+ )
+ ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += f"\tif (val & ~mask)\n"
+ ret += "\t\treturn lib9p_errorf(ctx->ctx,\n"
+ ret += f'\t\t LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n'
+ ret += "\t\t val & ~mask);\n"
+ ret += "\treturn false;\n"
+ case Struct(): # and Message()
+ if len(typ.members) == 0:
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ # Pass 1
+ for member in typ.members:
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ if member.max or member.val:
+ ret += f"\t{c_typename(idprefix, member.typ)} {member.name};\n"
+ pop_ifdef(1)
+
+ # Pass 2
+ mark_offset: set[str] = set()
+ for member in typ.members:
+ for tok in [*member.max.tokens, *member.val.tokens]:
+ if isinstance(tok, ExprSym) and tok.name.startswith("&"):
+ if tok.name[1:] not in mark_offset:
+ ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
+ mark_offset.add(tok.name[1:])
+
+ # Pass 3
+ ret += "\treturn false\n"
+ prev_size: int | None = None
+ for member in typ.members:
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ ret += f"\n\t|| "
+ if member.in_versions != typ.in_versions:
+ ret += "( " + c_ver_cond(idprefix, member.in_versions) + " && "
+ if member.cnt is not None:
+ assert prev_size
+ ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
+ else:
+ if member.max or member.val:
+ ret += "("
+ if member.name in mark_offset:
+ ret += f"({{ _{member.name}_offset = ctx->net_offset; "
+ ret += f"validate_{member.typ.name}(ctx)"
+ if member.name in mark_offset:
+ ret += "; })"
+ if member.max or member.val:
+ bytes = member.static_size
+ assert bytes
+ bits = bytes * 8
+ ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
+ if member.in_versions != typ.in_versions:
+ ret += " )"
+ prev_size = member.static_size
+
+ # Pass 4
+ for member in typ.members:
+ if member.max:
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ ret += f"\n\t|| ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
+ ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
+ if member.val:
+ pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions))
+ ret += f"\n\t|| ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
+ ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'
+
+ pop_ifdef(1)
+ ret += "\t;\n"
+ ret += "}\n"
+ pop_ifdef(0)
+
+ # # unmarshal_* ##############################################################
+ # ret += """
+ # /* unmarshal_* ****************************************************************/
+
+ # static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
+ # *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 1;
+ # }
+
+ # static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
+ # *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 2;
+ # }
+
+ # static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
+ # *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 4;
+ # }
+
+ # static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
+ # *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 8;
+ # }
+ # """
+ # for typ in typs:
+ # inline = (
+ # " FLATTEN"
+ # if (isinstance(typ, Struct) and typ.msgid is not None)
+ # else " ALWAYS_INLINE"
+ # )
+ # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ # ret += "\n"
+ # ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n"
+ # match typ:
+ # case Bitfield():
+ # ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n"
+ # case Struct():
+ # ret += "\tmemset(out, 0, sizeof(*out));\n"
+
+ # if typ.members:
+ # struct_versions = typ.members[0].ver
+ # for member in typ.members:
+ # if member.valexpr:
+ # ret += f"\tctx->net_offset += {member.static_size};\n"
+ # continue
+ # ret += "\t"
+ # prefix = "\t"
+ # if member.ver != struct_versions:
+ # ret += "if ( " + c_ver_cond(idprefix, member.ver) + " ) "
+ # prefix = "\t\t"
+ # if member.cnt:
+ # if member.ver != struct_versions:
+ # ret += f"{{\n{prefix}"
+ # ret += f"out->{member.name} = ctx->extra;\n"
+ # ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
+ # ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
+ # if typ.name in ["d", "s"]: # SPECIAL
+ # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
+ # else:
+ # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
+ # if member.ver != struct_versions:
+ # ret += "\t}\n"
+ # else:
+ # ret += (
+ # f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
+ # )
+ # if typ.name == "s": # SPECIAL
+ # ret += "\tctx->extra++;\n"
+ # ret += "\tout->utf8[out->len] = '\\0';\n"
+ # ret += "}\n"
+
+ # # marshal_* ################################################################
+ # ret += """
+ # /* marshal_* ******************************************************************/
+
+ # static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
+ # lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
+ # (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
+ # ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
+ # ctx->ctx->max_msg_size);
+ # return true;
+ # }
+
+ # static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
+ # if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
+ # return _marshal_too_large(ctx);
+ # ctx->net_bytes[ctx->net_offset] = *val;
+ # ctx->net_offset += 1;
+ # return false;
+ # }
+
+ # static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
+ # if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
+ # return _marshal_too_large(ctx);
+ # encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 2;
+ # return false;
+ # }
+
+ # static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
+ # if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
+ # return true;
+ # encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 4;
+ # return false;
+ # }
+
+ # static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
+ # if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
+ # return true;
+ # encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
+ # ctx->net_offset += 8;
+ # return false;
+ # }
+ # """
+ # for typ in typs:
+ # inline = (
+ # " FLATTEN"
+ # if (isinstance(typ, Struct) and typ.msgid is not None)
+ # else " ALWAYS_INLINE"
+ # )
+ # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ # ret += "\n"
+ # ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{"
+ # match typ:
+ # case Bitfield():
+ # ret += "\n"
+ # ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n"
+ # case Struct():
+ # if len(typ.members) == 0:
+ # ret += "\n\treturn false;\n"
+ # ret += "}\n"
+ # continue
+
+ # mark_offset = set()
+ # for member in typ.members:
+ # if member.valexpr:
+ # if member.name not in mark_offset:
+ # ret += f"\n\tuint32_t _{member.name}_offset;"
+ # mark_offset.add(member.name)
+ # for tok in member.valexpr:
+ # if (
+ # isinstance(tok, ExprVal)
+ # and tok.name.startswith("&")
+ # and tok.name[1:] not in mark_offset
+ # ):
+ # ret += f"\n\tuint32_t _{tok.name[1:]}_offset;"
+ # mark_offset.add(tok.name[1:])
+
+ # prefix0 = "\treturn "
+ # prefix1 = "\t || "
+ # prefix2 = "\t "
+
+ # struct_versions = typ.members[0].ver
+ # prefix = prefix0
+ # for member in typ.members:
+ # ret += f"\n{prefix}"
+ # if member.ver != struct_versions:
+ # ret += "( " + c_ver_cond(idprefix, member.ver) + " && "
+ # if member.name in mark_offset:
+ # ret += f"({{ _{member.name}_offset = ctx->net_offset; "
+ # if member.cnt:
+ # ret += "({"
+ # ret += f"\n{prefix2}\tbool err = false;"
+ # ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
+ # if typ.name in ["d", "s"]: # SPECIAL
+ # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);"
+ # else:
+ # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
+ # ret += f"\n{prefix2}\terr;"
+ # ret += f"\n{prefix2}}})"
+ # elif member.valexpr:
+ # assert member.static_size
+ # ret += (
+ # f"({{ ctx->net_offset += {member.static_size}; false; }})"
+ # )
+ # else:
+ # ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
+ # if member.name in mark_offset:
+ # ret += "; })"
+ # if member.ver != struct_versions:
+ # ret += " )"
+ # prefix = prefix1
+
+ # for member in typ.members:
+ # if member.valexpr:
+ # assert member.static_size
+ # ret += f"\n{prefix}"
+ # ret += f"({{ encode_u{member.static_size*8}le("
+ # for tok in member.valexpr:
+ # match tok:
+ # case ExprOp():
+ # ret += f" {tok.op}"
+ # case ExprVal(name="end"):
+ # ret += " ctx->net_offset"
+ # case ExprVal():
+ # ret += f" _{tok.name[1:]}_offset"
+ # ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})"
+
+ # ret += ";\n"
+ # ret += "}\n"
+
+ # # vtables ##################################################################
+ # ret += f"""
+ # /* vtables ********************************************************************/
+
+ # #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
+ # .basesize = sizeof(struct {idprefix}msg_##typ), \\
+ # .validate = validate_##typ, \\
+ # .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
+ # .marshal = (_marshal_fn_t)marshal_##typ, \\
+ # }}
+
+ # struct _vtable_version _{idprefix}vtables[{c_ver_enum(idprefix, 'NUM')}] = {{
+ # """
+
+ # ret += f"\t[{c_ver_enum(idprefix, 'unknown')}] = {{ .msgs = {{\n"
+ # for msg in just_structs_msg(typs):
+ # if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL
+ # ret += f"\t\t_MSG({msg.name}),\n"
+ # ret += "\t}},\n"
+
+ # for ver in sorted(versions):
+ # ret += f"\t[{c_ver_enum(idprefix, ver)}] = {{ .msgs = {{\n"
+ # for msg in just_structs_msg(typs):
+ # if ver not in msg.msgver:
+ # continue
+ # ret += f"\t\t_MSG({msg.name}),\n"
+ # ret += "\t}},\n"
+ # ret += "};\n"
+
+ ############################################################################
+ return ret
+
+
+################################################################################
+
+
+class Parser:
+ cache: dict[str, tuple[str, list[Type]]] = {}
+
+ def parse_file(self, filename: str) -> tuple[str, list[Type]]:
+ filename = os.path.normpath(filename)
+ if filename not in self.cache:
+
+ def get_include(other_filename: str) -> tuple[str, list[Type]]:
+ return self.parse_file(os.path.join(filename, "..", other_filename))
+
+ self.cache[filename] = parse_file(filename, get_include)
+ return self.cache[filename]
+
+ def all(self) -> tuple[set[str], list[Type]]:
+ ret_versions: set[str] = set()
+ ret_typs: dict[str, Type] = {}
+ for version, typs in self.cache.values():
+ if version in ret_versions:
+ raise ValueError(f"duplicate protocol version {repr(version)}")
+ ret_versions.add(version)
+ for typ in typs:
+ if typ.name in ret_typs:
+ if typ != ret_typs[typ.name]:
+ raise ValueError(f"duplicate type name {repr(typ.name)}")
+ else:
+ ret_typs[typ.name] = typ
+ return ret_versions, list(ret_typs.values())
+
+
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 2:
+ raise ValueError("requires at least 1 .9p filename")
+ parser = Parser()
+ for txtname in sys.argv[1:]:
+ parser.parse_file(txtname)
+ versions, typs = parser.all()
+ outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
+ with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
+ fh.write(gen_h("lib9p_", versions, typs))
+ with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
+ fh.write(gen_c("lib9p_", versions, typs))