summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-23 03:05:18 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-23 03:05:18 -0600
commitc275032964505d3ceecf3cc0ce21b059ede930dd (patch)
tree0cb858fd7b55f19eda2d027e628b580aab155342
parent1e49f931a86e54415ba36fc2bfe799d616936080 (diff)
parent82b733e4f8b3febc3b51c133a52fb62b54180b4b (diff)
Merge branch 'lukeshu/9p-gen-split'
-rw-r--r--.editorconfig2
-rw-r--r--GNUmakefile4
-rw-r--r--lib9p/9p.generated.c2
-rwxr-xr-xlib9p/idl.gen1439
-rw-r--r--lib9p/idl/0000-TODO.md2
-rw-r--r--lib9p/include/lib9p/9p.generated.h2
-rwxr-xr-xlib9p/proto.gen15
-rw-r--r--lib9p/protogen/__init__.py57
-rw-r--r--lib9p/protogen/c.py200
-rw-r--r--lib9p/protogen/c9util.py117
-rw-r--r--lib9p/protogen/c_marshal.py357
-rw-r--r--lib9p/protogen/c_unmarshal.py92
-rw-r--r--lib9p/protogen/c_validate.py171
-rw-r--r--lib9p/protogen/cutil.py84
-rw-r--r--lib9p/protogen/h.py447
-rw-r--r--lib9p/protogen/idlutil.py112
16 files changed, 1658 insertions, 1445 deletions
diff --git a/.editorconfig b/.editorconfig
index 69fefd5..f907b33 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -36,7 +36,7 @@ _mode = sh
[{build-aux/lint-h,build-aux/lint-bin,build-aux/get-dscname,build-aux/linux-errno.txt.gen,libusb/include/libusb/tusb_helpers.h.gen,lib9p/tests/runtest}]
_mode = bash
-[{lib9p/idl.gen,lib9p/include/lib9p/linux-errno.h.gen,build-aux/stack.c.gen}]
+[{lib9p/proto.gen,lib9p/include/lib9p/linux-errno.h.gen,build-aux/stack.c.gen}]
_mode = python3
indent_style = space
indent_size = 4
diff --git a/GNUmakefile b/GNUmakefile
index 339ae4b..2d11c8b 100644
--- a/GNUmakefile
+++ b/GNUmakefile
@@ -42,7 +42,7 @@ lib9p/include/lib9p/linux-errno.h: %: %.gen 3rd-party/linux-errno.txt
$^ >$@
generate/files += lib9p/9p.generated.c lib9p/include/lib9p/9p.generated.h
-lib9p/9p.generated.c lib9p/include/lib9p/9p.generated.h &: lib9p/idl.gen lib9p/idl/__init__.py lib9p/idl lib9p/idl/*.9p
+lib9p/9p.generated.c lib9p/include/lib9p/9p.generated.h &: lib9p/proto.gen lib9p/idl/__init__.py lib9p/protogen lib9p/protogen/*.py lib9p/idl lib9p/idl/*.9p
$< $(filter %.9p,$^)
generate/files += lib9p/tests/test_compile.c
@@ -123,7 +123,7 @@ lint/python3: lint/%: build-aux/venv
./build-aux/venv/bin/black --check $(sources_$*)
./build-aux/venv/bin/isort --check $(sources_$*)
./build-aux/venv/bin/pylint $(sources_$*)
- ! grep -nh 'SPECIAL$$' -- lib9p/idl.gen
+ ! grep -nh 'SPECIAL$$' -- lib9p/proto.gen lib9p/protogen/*.py
lint/c: lint/%: build-aux/lint-h build-aux/get-dscname
./build-aux/lint-h $(filter %.h,$(sources_$*))
lint/make lint/cmake lint/gitignore lint/ini lint/9p lint/markdown lint/pip lint/man-cat: lint/%:
diff --git a/lib9p/9p.generated.c b/lib9p/9p.generated.c
index 8d6c82f..af5200f 100644
--- a/lib9p/9p.generated.c
+++ b/lib9p/9p.generated.c
@@ -1,4 +1,4 @@
-/* Generated by `lib9p/idl.gen lib9p/idl/2002-9P2000.9p lib9p/idl/2003-9P2000.p9p.9p lib9p/idl/2005-9P2000.u.9p lib9p/idl/2010-9P2000.L.9p lib9p/idl/2012-9P2000.e.9p`. DO NOT EDIT! */
+/* Generated by `lib9p/proto.gen lib9p/idl/2002-9P2000.9p lib9p/idl/2003-9P2000.p9p.9p lib9p/idl/2005-9P2000.u.9p lib9p/idl/2010-9P2000.L.9p lib9p/idl/2012-9P2000.e.9p`. DO NOT EDIT! */
#include <stdbool.h>
#include <stddef.h> /* for size_t */
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
deleted file mode 100755
index eaeca49..0000000
--- a/lib9p/idl.gen
+++ /dev/null
@@ -1,1439 +0,0 @@
-#!/usr/bin/env python
-# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
-# defining 9P protocol variants.
-#
-# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
-# SPDX-License-Identifier: AGPL-3.0-or-later
-
-import enum
-import graphlib
-import os.path
-import sys
-import typing
-
-sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
-import idl # pylint: disable=wrong-import-position,import-self
-
-# This strives to be "general-purpose" in that it just acts on the
-# *.9p inputs; but (unfortunately?) there are a few special-cases in
-# this script, marked with "SPECIAL".
-
-
-# Utilities ####################################################################
-
-idprefix = "lib9p_"
-
-u32max = (1 << 32) - 1
-u64max = (1 << 64) - 1
-
-
-def tab_ljust(s: str, width: int) -> str:
- cur = len(s.expandtabs(tabsize=8))
- if cur >= width:
- return s
- return s + " " * (width - cur)
-
-
-def add_prefix(p: str, s: str) -> str:
- if s.startswith("_"):
- return "_" + p + s[1:]
- return p + s
-
-
-def c_macro(full: str) -> str:
- full = full.rstrip()
- assert "\n" in full
- lines = [l.rstrip() for l in full.split("\n")]
- width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1])
- lines = [tab_ljust(l, width) for l in lines]
- return " \\\n".join(lines).rstrip() + "\n"
-
-
-def c_ver_enum(ver: str) -> str:
- return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
-
-
-def c_ver_ifdef(versions: typing.Collection[str]) -> str:
- return " || ".join(
- f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions)
- )
-
-
-def c_ver_cond(versions: typing.Collection[str]) -> str:
- if len(versions) == 1:
- v = next(v for v in versions)
- return f"is_ver(ctx, {v.replace('.', '_')})"
- return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-
-
-def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
- match typ:
- case idl.Primitive():
- if typ.value == 1 and parent and parent.cnt: # SPECIAL (string)
- return "[[gnu::nonstring]] char"
- return f"uint{typ.value*8}_t"
- case idl.Number():
- return f"{idprefix}{typ.typname}_t"
- case idl.Bitfield():
- return f"{idprefix}{typ.typname}_t"
- case idl.Message():
- return f"struct {idprefix}msg_{typ.typname}"
- case idl.Struct():
- return f"struct {idprefix}{typ.typname}"
- case _:
- raise ValueError(f"not a type: {typ.__class__.__name__}")
-
-
-def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
- ret: list[str] = []
- for tok in expr.tokens:
- match tok:
- case idl.ExprOp():
- ret.append(tok.op)
- case idl.ExprLit():
- ret.append(str(tok.val))
- case idl.ExprSym(symname="s32_max"):
- ret.append("INT32_MAX")
- case idl.ExprSym(symname="s64_max"):
- ret.append("INT64_MAX")
- case idl.ExprSym():
- ret.append(lookup_sym(tok.symname))
- case _:
- assert False
- return " ".join(ret)
-
-
-_ifdef_stack: list[str | None] = []
-
-
-def ifdef_push(n: int, _newval: str) -> str:
- # Grow the stack as needed
- while len(_ifdef_stack) < n:
- _ifdef_stack.append(None)
-
- # Set some variables
- parentval: str | None = None
- for x in _ifdef_stack[:-1]:
- if x is not None:
- parentval = x
- oldval = _ifdef_stack[-1]
- newval: str | None = _newval
- if newval == parentval:
- newval = None
-
- # Put newval on the stack.
- _ifdef_stack[-1] = newval
-
- # Build output.
- ret = ""
- if newval != oldval:
- if oldval is not None:
- ret += f"#endif /* {oldval} */\n"
- if newval is not None:
- ret += f"#if {newval}\n"
- return ret
-
-
-def ifdef_pop(n: int) -> str:
- global _ifdef_stack
- ret = ""
- while len(_ifdef_stack) > n:
- if _ifdef_stack[-1] is not None:
- ret += f"#endif /* {_ifdef_stack[-1]} */\n"
- _ifdef_stack = _ifdef_stack[:-1]
- return ret
-
-
-# topo_sorted() ################################################################
-
-
-def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
- ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
- for typ in typs:
- match typ:
- case idl.Number():
- ts.add(typ)
- case idl.Bitfield():
- ts.add(typ)
- case idl.Struct(): # and idl.Message():
- deps = [
- member.typ
- for member in typ.members
- if not isinstance(member.typ, idl.Primitive)
- ]
- ts.add(typ, *deps)
- return ts.static_order()
-
-
-# walk() #######################################################################
-
-
-class Path:
- root: idl.Type
- elems: list[idl.StructMember]
-
- def __init__(
- self, root: idl.Type, elems: list[idl.StructMember] | None = None
- ) -> None:
- self.root = root
- self.elems = elems if elems is not None else []
-
- def add(self, elem: idl.StructMember) -> "Path":
- return Path(self.root, self.elems + [elem])
-
- def parent(self) -> "Path":
- return Path(self.root, self.elems[:-1])
-
- def c_str(self, base: str, loopdepth: int = 0) -> str:
- ret = base
- for i, elem in enumerate(self.elems):
- if i > 0:
- ret += "."
- ret += elem.membname
- if elem.cnt:
- ret += f"[{chr(ord('i')+loopdepth)}]"
- loopdepth += 1
- return ret
-
- def __str__(self) -> str:
- return self.c_str(self.root.typname + "->")
-
-
-class WalkCmd(enum.Enum):
- KEEP_GOING = 1
- DONT_RECURSE = 2
- ABORT = 3
-
-
-type WalkHandler = typing.Callable[
- [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
-]
-
-
-def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
- typ = path.elems[-1].typ if path.elems else path.root
-
- ret, atexit = handle(path)
-
- if isinstance(typ, idl.Struct):
- match ret:
- case WalkCmd.KEEP_GOING:
- for member in typ.members:
- if _walk(path.add(member), handle) == WalkCmd.ABORT:
- ret = WalkCmd.ABORT
- break
- case WalkCmd.DONT_RECURSE:
- ret = WalkCmd.KEEP_GOING
- case WalkCmd.ABORT:
- ret = WalkCmd.ABORT
- case _:
- assert False, f"invalid cmd: {ret}"
-
- if atexit:
- atexit()
- return ret
-
-
-def walk(typ: idl.Type, handle: WalkHandler) -> None:
- _walk(Path(typ), handle)
-
-
-# get_buffer_size() ############################################################
-
-
-class BufferSize(typing.NamedTuple):
- min_size: int # really just here to sanity-check against typ.min_size(version)
- exp_size: int # "expected" or max-reasonable size
- max_size: int # really just here to sanity-check against typ.max_size(version)
- max_copy: int
- max_copy_extra: str
- max_iov: int
- max_iov_extra: str
-
-
-class TmpBufferSize:
- min_size: int
- exp_size: int
- max_size: int
- max_copy: int
- max_copy_extra: str
- max_iov: int
- max_iov_extra: str
-
- tmp_starts_with_copy: bool
- tmp_ends_with_copy: bool
-
- def __init__(self) -> None:
- self.min_size = 0
- self.exp_size = 0
- self.max_size = 0
- self.max_copy = 0
- self.max_copy_extra = ""
- self.max_iov = 0
- self.max_iov_extra = ""
- self.tmp_starts_with_copy = False
- self.tmp_ends_with_copy = False
-
-
-def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
- assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)
-
- ret = TmpBufferSize()
-
- if not isinstance(typ, idl.Struct):
- assert typ.static_size
- ret.min_size = typ.static_size
- ret.exp_size = typ.static_size
- ret.max_size = typ.static_size
- ret.max_copy = typ.static_size
- ret.max_iov = 1
- ret.tmp_starts_with_copy = True
- ret.tmp_ends_with_copy = True
- return ret
-
- def handle(path: Path) -> tuple[WalkCmd, None]:
- nonlocal ret
- if path.elems:
- child = path.elems[-1]
- if version not in child.in_versions:
- return WalkCmd.DONT_RECURSE, None
- if child.cnt:
- if child.typ.static_size == 1: # SPECIAL (zerocopy)
- ret.max_iov += 1
- # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
- ret.exp_size += 27 if child.membname == "utf8" else 8192
- ret.max_size += child.max_cnt
- ret.tmp_ends_with_copy = False
- return WalkCmd.DONT_RECURSE, None
- sub = _get_buffer_size(child.typ, version)
- ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM
- ret.max_size += sub.max_size * child.max_cnt
- if child.membname == "wname" and path.root.typname in (
- "Tsread",
- "Tswrite",
- ): # SPECIAL (9P2000.e)
- assert ret.tmp_ends_with_copy
- assert sub.tmp_starts_with_copy
- assert not sub.tmp_ends_with_copy
- ret.max_copy_extra = (
- f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
- )
- ret.max_iov_extra = (
- f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
- )
- ret.max_iov -= 1
- else:
- ret.max_copy += sub.max_copy * child.max_cnt
- if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy
- ret.max_iov += 1
- else: # contains zero-copy segments
- ret.max_iov += sub.max_iov * child.max_cnt
- if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
- # we can merge this one
- ret.max_iov -= 1
- if (
- sub.tmp_ends_with_copy
- and sub.tmp_starts_with_copy
- and sub.max_iov > 1
- ):
- # we can merge these
- ret.max_iov -= child.max_cnt - 1
- ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
- return WalkCmd.DONT_RECURSE, None
- if not isinstance(child.typ, idl.Struct):
- assert child.typ.static_size
- if not ret.tmp_ends_with_copy:
- if ret.max_size == 0:
- ret.tmp_starts_with_copy = True
- ret.max_iov += 1
- ret.tmp_ends_with_copy = True
- ret.min_size += child.typ.static_size
- ret.exp_size += child.typ.static_size
- ret.max_size += child.typ.static_size
- ret.max_copy += child.typ.static_size
- return WalkCmd.KEEP_GOING, None
-
- walk(typ, handle)
- assert ret.min_size == typ.min_size(version)
- assert ret.max_size == typ.max_size(version)
- return ret
-
-
-def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
- tmp = _get_buffer_size(typ, version)
- return BufferSize(
- min_size=tmp.min_size,
- exp_size=tmp.exp_size,
- max_size=tmp.max_size,
- max_copy=tmp.max_copy,
- max_copy_extra=tmp.max_copy_extra,
- max_iov=tmp.max_iov,
- max_iov_extra=tmp.max_iov_extra,
- )
-
-
-# Generate .h ##################################################################
-
-
-def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
- global _ifdef_stack
- _ifdef_stack = []
-
- ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
-
-#ifndef _LIB9P_9P_H_
-\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
-#endif
-
-#include <stdint.h> /* for uint{{n}}_t types */
-
-#include <libhw/generic/net.h> /* for struct iovec */
-"""
-
- id2typ: dict[int, idl.Message] = {}
- for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
- id2typ[msg.msgid] = msg
-
- ret += """
-/* config *********************************************************************/
-
-#include "config.h"
-"""
- for ver in sorted(versions):
- ret += "\n"
- ret += f"#ifndef {c_ver_ifdef({ver})}\n"
- ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
- if ver == "9P2000.e": # SPECIAL (9P2000.e)
- ret += "#else\n"
- ret += f"\t#if {c_ver_ifdef({ver})}\n"
- ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
- ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
- ret += "\t\t#endif\n"
- ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
- ret += "\t#endif\n"
- ret += "#endif\n"
-
- ret += f"""
-/* enum version ***************************************************************/
-
-enum {idprefix}version {{
-"""
- fullversions = ["unknown = 0", *sorted(versions)]
- verwidth = max(len(v) for v in fullversions)
- for ver in fullversions:
- if ver in versions:
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t{c_ver_enum(ver)},"
- ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
- ret += ifdef_pop(0)
- ret += f"\t{c_ver_enum('NUM')},\n"
- ret += "};\n"
-
- ret += """
-/* enum msg_type **************************************************************/
-
-"""
- ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
- namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
- for n in range(0x100):
- if n not in id2typ:
- continue
- msg = id2typ[n]
- ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n"
- ret += ifdef_pop(0)
- ret += "};\n"
-
- ret += """
-/* payload types **************************************************************/
-"""
-
- def per_version_comment(
- typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str]
- ) -> str:
- lines: dict[str, str] = {}
- for version in sorted(typ.in_versions):
- lines[version] = fn(typ, version)
- if len(set(lines.values())) == 1:
- for _, line in lines.items():
- return f"/* {line} */\n"
- assert False
- else:
- ret = ""
- v_width = max(len(c_ver_enum(v)) for v in typ.in_versions)
- for version, line in lines.items():
- ret += f"/* {c_ver_enum(version):<{v_width}}: {line} */\n"
- return ret
-
- for typ in topo_sorted(typs):
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
-
- def sum_size(typ: idl.UserType, version: str) -> str:
- sz = get_buffer_size(typ, version)
- assert (
- sz.min_size <= sz.exp_size
- and sz.exp_size <= sz.max_size
- and sz.max_size < u64max
- )
- ret = ""
- if sz.min_size == sz.max_size:
- ret += f"size = {sz.min_size:,}"
- else:
- ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}"
- if sz.max_size > u32max:
- ret += " (warning: >UINT32_MAX)"
- ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}"
- return ret
-
- ret += per_version_comment(typ, sum_size)
-
- match typ:
- case idl.Number():
- ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
- prefix = f"{idprefix.upper()}{typ.typname.upper()}_"
- namewidth = max(len(name) for name in typ.vals)
- for name, val in typ.vals.items():
- ret += f"#define {prefix}{name:<{namewidth}} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
- case idl.Bitfield():
- ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
-
- def bitname(val: idl.Bit | idl.BitAlias) -> str:
- s = val.bitname
- match val:
- case idl.Bit(cat=idl.BitCat.RESERVED):
- s = "_RESERVED_" + s
- case idl.Bit(cat=idl.BitCat.SUBFIELD):
- assert isinstance(typ, idl.Bitfield)
- n = sum(
- 1
- for b in typ.bits[: val.num]
- if b.cat == idl.BitCat.SUBFIELD
- and b.bitname == val.bitname
- )
- s = f"_{s}_{n}"
- case idl.Bit(cat=idl.BitCat.UNUSED):
- return ""
- return add_prefix(f"{idprefix.upper()}{typ.typname.upper()}_", s)
-
- namewidth = max(
- len(bitname(val)) for val in [*typ.bits, *typ.names.values()]
- )
-
- ret += "\n"
- for bit in reversed(typ.bits):
- vers = bit.in_versions
- if bit.cat == idl.BitCat.UNUSED:
- vers = typ.in_versions
- ret += ifdef_push(2, c_ver_ifdef(vers))
-
- # It is important all of the `beg` strings have
- # the same length.
- end = ""
- match bit.cat:
- case (
- idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD
- ):
- if _ifdef_stack[-1]:
- beg = "# define"
- else:
- beg = "#define "
- case idl.BitCat.UNUSED:
- beg = "/* unused"
- end = " */"
-
- c_name = bitname(bit)
- c_val = f"1<<{bit.num}"
- ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n"
- if aliases := [
- alias
- for alias in typ.names.values()
- if isinstance(alias, idl.BitAlias)
- ]:
- ret += "\n"
-
- for alias in aliases:
- ret += ifdef_push(2, c_ver_ifdef(alias.in_versions))
-
- end = ""
- if _ifdef_stack[-1]:
- beg = "# define"
- else:
- beg = "#define "
-
- c_name = bitname(alias)
- c_val = alias.val
- ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n"
- ret += ifdef_pop(1)
- del bitname
- case idl.Struct(): # and idl.Message():
- ret += c_typename(typ) + " {"
- if not typ.members:
- ret += "};\n"
- continue
- ret += "\n"
-
- typewidth = max(len(c_typename(m.typ, m)) for m in typ.members)
-
- for member in typ.members:
- if member.val:
- continue
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n"
- ret += ifdef_pop(1)
- ret += "};\n"
- del typ
- ret += ifdef_pop(0)
-
- ret += """
-/* containers *****************************************************************/
-"""
- ret += "\n"
- ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n"
-
- tmsg_max_iov: dict[str, int] = {}
- tmsg_max_copy: dict[str, int] = {}
- rmsg_max_iov: dict[str, int] = {}
- rmsg_max_copy: dict[str, int] = {}
- for typ in typs:
- if not isinstance(typ, idl.Message):
- continue
- if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e)
- continue
- max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
- max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
- for version in typ.in_versions:
- if version not in max_iov:
- max_iov[version] = 0
- max_copy[version] = 0
- sz = get_buffer_size(typ, version)
- if sz.max_iov > max_iov[version]:
- max_iov[version] = sz.max_iov
- if sz.max_copy > max_copy[version]:
- max_copy[version] = sz.max_copy
-
- for name, table in [
- ("tmsg_max_iov", tmsg_max_iov),
- ("tmsg_max_copy", tmsg_max_copy),
- ("rmsg_max_iov", rmsg_max_iov),
- ("rmsg_max_copy", rmsg_max_copy),
- ]:
- inv: dict[int, set[str]] = {}
- for version, maxval in table.items():
- if maxval not in inv:
- inv[maxval] = set()
- inv[maxval].add(version)
-
- ret += "\n"
- directive = "if"
- seen_e = False # SPECIAL (9P2000.e)
- for maxval in sorted(inv, reverse=True):
- ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n"
- indent = 1
- if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
- typ = next(typ for typ in typs if typ.typname == "Tswrite")
- sz = get_buffer_size(typ, "9P2000.e")
- match name:
- case "tmsg_max_iov":
- maxexpr = f"{sz.max_iov}{sz.max_iov_extra}"
- case "tmsg_max_copy":
- maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
- case _:
- assert False
- ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n"
- ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n"
- ret += "\t#else\n"
- indent += 1
- ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n"
- if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
- ret += "\t#endif\n"
- if "9P2000.e" in inv[maxval]:
- seen_e = True
- directive = "elif"
- ret += "#endif\n"
-
- ret += "\n"
- ret += f"struct {idprefix}Tmsg_send_buf {{\n"
- ret += "\tsize_t iov_cnt;\n"
- ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n"
- ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n"
- ret += "};\n"
-
- ret += "\n"
- ret += f"struct {idprefix}Rmsg_send_buf {{\n"
- ret += "\tsize_t iov_cnt;\n"
- ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n"
- ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n"
- ret += "};\n"
-
- return ret
-
-
-# Generate .c ##################################################################
-
-
-def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
- global _ifdef_stack
- _ifdef_stack = []
-
- ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
-
-#include <stdbool.h>
-#include <stddef.h> /* for size_t */
-#include <inttypes.h> /* for PRI* macros */
-#include <string.h> /* for memset() */
-
-#include <libmisc/assert.h>
-
-#include <lib9p/9p.h>
-
-#include "internal.h"
-"""
-
- # utilities ################################################################
- ret += """
-/* utilities ******************************************************************/
-"""
-
- def used(arg: str) -> str:
- return arg
-
- def unused(arg: str) -> str:
- return f"LM_UNUSED({arg})"
-
- id2typ: dict[int, idl.Message] = {}
- for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
- id2typ[msg.msgid] = msg
-
- def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
- ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
- for ver in ["unknown", *sorted(versions)]:
- if ver != "unknown":
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {{\n"
- for n in range(*rng):
- xmsg: idl.Message | None = id2typ.get(n, None)
- if xmsg:
- if ver == "unknown": # SPECIAL (initialization)
- if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]:
- xmsg = None
- else:
- if ver not in xmsg.in_versions:
- xmsg = None
- if xmsg:
- ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n"
- ret += "\t},\n"
- ret += ifdef_pop(0)
- ret += "};\n"
- return ret
-
- for v in sorted(versions):
- ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
- ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
- ret += "#else\n"
- ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
- ret += "#endif\n"
- ret += "\n"
- ret += "/**\n"
- ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n"
- ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n"
- ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n"
- ret += " * several version checks together.\n"
- ret += " */\n"
- ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n"
-
- # strings ##################################################################
- ret += f"""
-/* strings ********************************************************************/
-
-const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
-"""
- for ver in ["unknown", *sorted(versions)]:
- if ver in versions:
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
- ret += ifdef_pop(0)
- ret += "};\n"
-
- ret += "\n"
- ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
- ret += msg_table("msg", "name", "char *const", (0, 0x100, 1))
-
- # bitmasks #################################################################
- ret += """
-/* bitmasks *******************************************************************/
-"""
- for typ in typs:
- if not isinstance(typ, idl.Bitfield):
- continue
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n"
- verwidth = max(len(ver) for ver in versions)
- for ver in sorted(versions):
- ret += ifdef_push(2, c_ver_ifdef({ver}))
- ret += (
- f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
- + "".join(
- (
- "1"
- if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD)
- and ver in bit.in_versions
- else "0"
- )
- for bit in reversed(typ.bits)
- )
- + ",\n"
- )
- ret += ifdef_pop(1)
- ret += "};\n"
- ret += ifdef_pop(0)
-
- # validate_* ###############################################################
- ret += """
-/* validate_* *****************************************************************/
-
-LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
-\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
-\t\t/* If needed-net-size overflowed uint32_t, then
-\t\t * there's no way that actual-net-size will live up to
-\t\t * that. */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\tif (ctx->net_offset > ctx->net_size)
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
-\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
-\t\t/* If needed-host-size overflowed size_t, then there's
-\t\t * no way that actual-net-size will live up to
-\t\t * that. */
-\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
- size_t cnt,
- _validate_fn_t item_fn, size_t item_host_size) {
-\tfor (size_t i = 0; i < cnt; i++)
-\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
-\t\t\treturn true;
-\treturn false;
-}
-
-LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
-LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
-LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
-LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
-"""
-
- def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
- return bool(
- member.max or member.val or any(m.cnt == member for m in typ.members)
- )
-
- for typ in topo_sorted(typs):
- inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
-
- match typ:
- case idl.Number():
- ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
- case idl.Bitfield():
- ret += f"\t if (validate_{typ.static_size}(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
- if typ.static_size == 1:
- ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
- else:
- ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
- ret += "\tif (val & ~mask)\n"
- ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
- ret += "\treturn false;\n"
- case idl.Struct(): # and idl.Message()
- if len(typ.members) == 0:
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
-
- # Pass 1 - declare value variables
- for member in typ.members:
- if should_save_value(typ, member):
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ)} {member.membname};\n"
- ret += ifdef_pop(1)
-
- # Pass 2 - declare offset variables
- mark_offset: set[str] = set()
- for member in typ.members:
- for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
- if tok.symname[1:] not in mark_offset:
- ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
- mark_offset.add(tok.symname[1:])
-
- # Pass 3 - main pass
- ret += "\treturn false\n"
- for member in typ.members:
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += "\t || "
- if member.in_versions != typ.in_versions:
- ret += "( " + c_ver_cond(member.in_versions) + " && "
- if member.cnt is not None:
- if member.typ.static_size == 1: # SPECIAL (zerocopy)
- ret += f"_validate_size_net(ctx, {member.cnt.membname})"
- else:
- ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))"
- if typ.typname == "s": # SPECIAL (string)
- ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
- else:
- if should_save_value(typ, member):
- ret += "("
- if member.membname in mark_offset:
- ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
- ret += f"validate_{member.typ.typname}(ctx)"
- if member.membname in mark_offset:
- ret += "; })"
- if should_save_value(typ, member):
- nbytes = member.static_size
- assert nbytes
- if nbytes == 1:
- ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
- else:
- ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
- if member.in_versions != typ.in_versions:
- ret += " )"
- ret += "\n"
-
- # Pass 4 - validate ,max= and ,val= constraints
- for member in typ.members:
-
- def lookup_sym(sym: str) -> str:
- match sym:
- case "end":
- return "ctx->net_offset"
- case _:
- assert sym.startswith("&")
- return f"_{sym[1:]}_offset"
-
- if member.max:
- assert member.static_size
- nbits = member.static_size * 8
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
- if member.val:
- assert member.static_size
- nbits = member.static_size * 8
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
-
- ret += ifdef_pop(1)
- ret += "\t ;\n"
- ret += "}\n"
- ret += ifdef_pop(0)
-
- # unmarshal_* ##############################################################
- ret += """
-/* unmarshal_* ****************************************************************/
-
-LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
-\t*out = ctx->net_bytes[ctx->net_offset];
-\tctx->net_offset += 1;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
-\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 2;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
-\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 4;
-}
-
-LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
-\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
-\tctx->net_offset += 8;
-}
-"""
- for typ in topo_sorted(typs):
- inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
- match typ:
- case idl.Number():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case idl.Bitfield():
- ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case idl.Struct():
- ret += "\tmemset(out, 0, sizeof(*out));\n"
-
- for member in typ.members:
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- if member.val:
- ret += f"\tctx->net_offset += {member.static_size};\n"
- continue
- ret += "\t"
-
- prefix = "\t"
- if member.in_versions != typ.in_versions:
- ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
- prefix = "\t\t"
- if member.cnt:
- if member.in_versions != typ.in_versions:
- ret += "{\n"
- ret += prefix
- if member.typ.static_size == 1: # SPECIAL (string, zerocopy)
- ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
- ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n"
- else:
- ret += f"out->{member.membname} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n"
- ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n"
- if member.in_versions != typ.in_versions:
- ret += "\t}\n"
- else:
- ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n"
- ret += ifdef_pop(1)
- ret += "}\n"
- ret += ifdef_pop(0)
-
- # marshal_* ################################################################
- ret += """
-/* marshal_* ******************************************************************/
-
-"""
- ret += c_macro(
- "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
- "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
- "\t\tctx->net_iov_cnt++;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
- "\tctx->net_iov_cnt++;\n"
- )
- ret += c_macro(
- "#define MARSHAL_BYTES(ctx, data, len)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
- "\tctx->net_copied_size += len;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
- )
- ret += c_macro(
- "#define MARSHAL_U8LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tctx->net_copied[ctx->net_copied_size] = val;\n"
- "\tctx->net_copied_size += 1;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
- )
- ret += c_macro(
- "#define MARSHAL_U16LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 2;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
- )
- ret += c_macro(
- "#define MARSHAL_U32LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 4;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
- )
- ret += c_macro(
- "#define MARSHAL_U64LE(ctx, val)\n"
- "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
- "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
- "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
- "\tctx->net_copied_size += 8;\n"
- "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
- )
-
- class OffsetExpr:
- static: int
- cond: dict[frozenset[str], "OffsetExpr"]
- rep: list[tuple[Path, "OffsetExpr"]]
-
- def __init__(self) -> None:
- self.static = 0
- self.rep = []
- self.cond = {}
-
- def add(self, other: "OffsetExpr") -> None:
- self.static += other.static
- self.rep += other.rep
- for k, v in other.cond.items():
- if k in self.cond:
- self.cond[k].add(v)
- else:
- self.cond[k] = v
-
- def gen_c(
- self,
- dsttyp: str,
- dstvar: str,
- root: str,
- indent_depth: int,
- loop_depth: int,
- ) -> str:
- oneline: list[str] = []
- multiline = ""
- if self.static:
- oneline.append(str(self.static))
- for cnt, sub in self.rep:
- if not sub.cond and not sub.rep:
- if sub.static == 1:
- oneline.append(cnt.c_str(root))
- else:
- oneline.append(f"({cnt.c_str(root)})*{sub.static}")
- continue
- loopvar = chr(ord("i") + loop_depth)
- multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
- multiline += sub.gen_c(
- "", dstvar, root, indent_depth + 1, loop_depth + 1
- )
- multiline += f"{'\t'*indent_depth}}}\n"
- for vers, sub in self.cond.items():
- multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers))
- multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n"
- multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
- multiline += f"{'\t'*indent_depth}}}\n"
- multiline += ifdef_pop(indent_depth)
- if dsttyp:
- if not oneline:
- oneline.append("0")
- ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
- elif oneline:
- ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
- ret += multiline
- return ret
-
- type OffsetExprRecursion = typing.Callable[[Path], WalkCmd]
-
- def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
- if not isinstance(typ, idl.Struct):
- assert typ.static_size
- ret = OffsetExpr()
- ret.static = typ.static_size
- return ret
-
- stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]]
-
- def pop_root() -> None:
- assert False
-
- def pop_cond() -> None:
- nonlocal stack
- key = frozenset(stack[-1][0].elems[-1].in_versions)
- if key in stack[-2][1].cond:
- stack[-2][1].cond[key].add(stack[-1][1])
- else:
- stack[-2][1].cond[key] = stack[-1][1]
- stack = stack[:-1]
-
- def pop_rep() -> None:
- nonlocal stack
- member_path = stack[-1][0]
- member = member_path.elems[-1]
- assert member.cnt
- cnt_path = member_path.parent().add(member.cnt)
- stack[-2][1].rep.append((cnt_path, stack[-1][1]))
- stack = stack[:-1]
-
- def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]:
- nonlocal recurse
-
- ret = recurse(path)
- if ret != WalkCmd.KEEP_GOING:
- return ret, None
-
- nonlocal stack
- stack_len = len(stack)
-
- def pop() -> None:
- nonlocal stack
- nonlocal stack_len
- while len(stack) > stack_len:
- stack[-1][2]()
-
- if path.elems:
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- if child.in_versions < parent.in_versions:
- stack.append((path, OffsetExpr(), pop_cond))
- if child.cnt:
- stack.append((path, OffsetExpr(), pop_rep))
- if not isinstance(child.typ, idl.Struct):
- assert child.typ.static_size
- stack[-1][1].static += child.typ.static_size
- return ret, pop
-
- stack = [(Path(typ), OffsetExpr(), pop_root)]
- walk(typ, handle)
- return stack[0][1]
-
- def go_to_end(path: Path) -> WalkCmd:
- return WalkCmd.KEEP_GOING
-
- def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]:
- def ret(path: Path) -> WalkCmd:
- if len(path.elems) == 1 and path.elems[0].membname == name:
- return WalkCmd.ABORT
- return WalkCmd.KEEP_GOING
-
- return ret
-
- for typ in typs:
- if not (
- isinstance(typ, idl.Message) or typ.typname == "stat"
- ): # SPECIAL (include stat)
- continue
- assert isinstance(typ, idl.Struct)
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n"
-
- # Pass 1 - check size
- max_size = max(typ.max_size(v) for v in typ.in_versions)
-
- if max_size > u32max: # SPECIAL (9P2000.e)
- ret += get_offset_expr(typ, go_to_end).gen_c(
- "uint64_t", "needed_size", "val->", 1, 0
- )
- ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
- else:
- ret += get_offset_expr(typ, go_to_end).gen_c(
- "uint32_t", "needed_size", "val->", 1, 0
- )
- ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
- if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
- ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
- ret += f'\t\t\t"{typ.typname}",\n'
- ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
- ret += "\t\t\tctx->ctx->max_msg_size);\n"
- ret += "\t\treturn true;\n"
- ret += "\t}\n"
-
- # Pass 2 - write data
- ifdef_depth = 1
- stack: list[tuple[Path, bool]] = [(Path(typ), False)]
-
- def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]:
- nonlocal ret
- nonlocal ifdef_depth
- nonlocal stack
- stack_len = len(stack)
-
- def pop() -> None:
- nonlocal ret
- nonlocal ifdef_depth
- nonlocal stack
- nonlocal stack_len
- while len(stack) > stack_len:
- ret += f"{'\t'*(len(stack)-1)}}}\n"
- if stack[-1][1]:
- ifdef_depth -= 1
- ret += ifdef_pop(ifdef_depth)
- stack = stack[:-1]
-
- loopdepth = sum(1 for elem in path.elems if elem.cnt)
- struct = path.elems[-1].typ if path.elems else path.root
- if isinstance(struct, idl.Struct):
- offsets: list[str] = []
- for member in struct.members:
- if not member.val:
- continue
- for tok in member.val.tokens:
- if not isinstance(tok, idl.ExprSym):
- continue
- if tok.symname == "end" or tok.symname.startswith("&"):
- if tok.symname not in offsets:
- offsets.append(tok.symname)
- for name in offsets:
- name_prefix = "offsetof_" + "".join(
- m.membname + "_" for m in path.elems
- )
- if name == "end":
- if not path.elems:
- nonlocal max_size
- if max_size > u32max:
- ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
- else:
- ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n"
- continue
- recurse: OffsetExprRecursion = go_to_end
- else:
- assert name.startswith("&")
- name = name[1:]
- recurse = go_to_tok(name)
- expr = get_offset_expr(struct, recurse)
- expr_prefix = path.c_str("val->", loopdepth)
- if not expr_prefix.endswith(">"):
- expr_prefix += "."
- ret += expr.gen_c(
- "uint32_t",
- name_prefix + name,
- expr_prefix,
- len(stack),
- loopdepth,
- )
- if path.elems:
- child = path.elems[-1]
- parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
- if child.in_versions < parent.in_versions:
- ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions))
- ifdef_depth += 1
- ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n"
- stack.append((path, True))
- if child.cnt:
- cnt_path = path.parent().add(child.cnt)
- if child.typ.static_size == 1: # SPECIAL (zerocopy)
- if path.root.typname == "stat": # SPECIAL (stat)
- ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
- else:
- ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
- return WalkCmd.KEEP_GOING, pop
- loopvar = chr(ord("i") + loopdepth - 1)
- ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
- stack.append((path, False))
- if not isinstance(child.typ, idl.Struct):
- if child.val:
-
- def lookup_sym(sym: str) -> str:
- nonlocal path
- if sym.startswith("&"):
- sym = sym[1:]
- return (
- "offsetof_"
- + "".join(m.membname + "_" for m in path.elems[:-1])
- + sym
- )
-
- val = c_expr(child.val, lookup_sym)
- else:
- val = path.c_str("val->")
- if isinstance(child.typ, idl.Bitfield):
- val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
- ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
- return WalkCmd.KEEP_GOING, pop
-
- walk(typ, handle)
- del handle
- del stack
- del max_size
-
- ret += "\treturn false;\n"
- ret += "}\n"
- ret += ifdef_pop(0)
-
- # function tables ##########################################################
- ret += """
-/* function tables ************************************************************/
-"""
-
- ret += "\n"
- ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n"
- rerror = next(typ for typ in typs if typ.typname == "Rerror")
- ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
- for ver in sorted(versions):
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n"
- ret += ifdef_pop(0)
- ret += "};\n"
-
- ret += "\n"
- ret += c_macro(
- f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
- f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n"
- f"\t\t.validate = validate_##typ,\n"
- f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
- f"\t}}\n"
- )
- ret += c_macro(
- f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
- f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n"
- f"\t}}\n"
- )
- ret += "\n"
- ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2))
- ret += "\n"
- ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2))
- ret += "\n"
- ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2))
- ret += "\n"
- ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2))
-
- ret += f"""
-LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{
-\treturn validate_stat(ctx);
-}}
-LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{
-\tunmarshal_stat(ctx, out);
-}}
-LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{
-\treturn marshal_stat(ctx, val);
-}}
-"""
-
- ############################################################################
- return ret
-
-
-# Main #########################################################################
-
-
-def main() -> None:
- if typing.TYPE_CHECKING:
-
- class ANSIColors:
- MAGENTA = "\x1b[35m"
- RED = "\x1b[31m"
- RESET = "\x1b[0m"
-
- else:
- from _colorize import ANSIColors # Present in Python 3.13+
-
- if len(sys.argv) < 2:
- raise ValueError("requires at least 1 .9p filename")
- parser = idl.Parser()
- for txtname in sys.argv[1:]:
- try:
- parser.parse_file(txtname)
- except SyntaxError as e:
- print(
- f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}",
- file=sys.stderr,
- )
- assert e.text
- print(f"\t{e.text}", file=sys.stderr)
- text_suffix = e.text.lstrip()
- text_prefix = e.text[: -len(text_suffix)]
- print(
- f"\t{text_prefix}{ANSIColors.RED}{'~'*len(text_suffix)}{ANSIColors.RESET}",
- file=sys.stderr,
- )
- sys.exit(2)
- versions, typs = parser.all()
- outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
- with open(
- os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8"
- ) as fh:
- fh.write(gen_h(versions, typs))
- with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh:
- fh.write(gen_c(versions, typs))
-
-
-if __name__ == "__main__":
- main()
diff --git a/lib9p/idl/0000-TODO.md b/lib9p/idl/0000-TODO.md
index d196ac9..81cafe5 100644
--- a/lib9p/idl/0000-TODO.md
+++ b/lib9p/idl/0000-TODO.md
@@ -1,6 +1,6 @@
<!--
lib9p/idl/0000-TODO.md - Changes I intend to make to idl/__init__.py
- and idl.gen
+ and proto.gen
Copyright (C) 2025 Luke T. Shumaker <lukeshu@lukeshu.com>
SPDX-License-Identifier: AGPL-3.0-or-later
diff --git a/lib9p/include/lib9p/9p.generated.h b/lib9p/include/lib9p/9p.generated.h
index 725e781..7a50537 100644
--- a/lib9p/include/lib9p/9p.generated.h
+++ b/lib9p/include/lib9p/9p.generated.h
@@ -1,4 +1,4 @@
-/* Generated by `lib9p/idl.gen lib9p/idl/2002-9P2000.9p lib9p/idl/2003-9P2000.p9p.9p lib9p/idl/2005-9P2000.u.9p lib9p/idl/2010-9P2000.L.9p lib9p/idl/2012-9P2000.e.9p`. DO NOT EDIT! */
+/* Generated by `lib9p/proto.gen lib9p/idl/2002-9P2000.9p lib9p/idl/2003-9P2000.p9p.9p lib9p/idl/2005-9P2000.u.9p lib9p/idl/2010-9P2000.L.9p lib9p/idl/2012-9P2000.e.9p`. DO NOT EDIT! */
#ifndef _LIB9P_9P_H_
#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
diff --git a/lib9p/proto.gen b/lib9p/proto.gen
new file mode 100755
index 0000000..60f1347
--- /dev/null
+++ b/lib9p/proto.gen
@@ -0,0 +1,15 @@
+#!/usr/bin/env python
+# lib9p/proto.gen - Generate C marshalers/unmarshalers for .9p files
+# defining 9P protocol variants.
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import os.path
+import sys
+
+sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
+import protogen # pylint: disable=wrong-import-position
+
+if __name__ == "__main__":
+ protogen.main()
diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py
new file mode 100644
index 0000000..c2c6173
--- /dev/null
+++ b/lib9p/protogen/__init__.py
@@ -0,0 +1,57 @@
+# lib9p/protogen/__init__.py - Generate C marshalers/unmarshalers for
+# .9p files defining 9P protocol variants
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import os.path
+import sys
+import typing
+
+import idl
+
+from . import c, h
+
+# pylint: disable=unused-variable
+__all__ = ["main"]
+
+
+def main() -> None:
+ if typing.TYPE_CHECKING:
+
+ class ANSIColors:
+ MAGENTA = "\x1b[35m"
+ RED = "\x1b[31m"
+ RESET = "\x1b[0m"
+
+ else:
+ from _colorize import ANSIColors # Present in Python 3.13+
+
+ if len(sys.argv) < 2:
+ raise ValueError("requires at least 1 .9p filename")
+ parser = idl.Parser()
+ for txtname in sys.argv[1:]:
+ try:
+ parser.parse_file(txtname)
+ except SyntaxError as e:
+ print(
+ f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}",
+ file=sys.stderr,
+ )
+ assert e.text
+ print(f"\t{e.text}", file=sys.stderr)
+ text_suffix = e.text.lstrip()
+ text_prefix = e.text[: -len(text_suffix)]
+ print(
+ f"\t{text_prefix}{ANSIColors.RED}{'~'*len(text_suffix)}{ANSIColors.RESET}",
+ file=sys.stderr,
+ )
+ sys.exit(2)
+ versions, typs = parser.all()
+ outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
+ with open(
+ os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8"
+ ) as fh:
+ fh.write(h.gen_h(versions, typs))
+ with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh:
+ fh.write(c.gen_c(versions, typs))
diff --git a/lib9p/protogen/c.py b/lib9p/protogen/c.py
new file mode 100644
index 0000000..a7e1773
--- /dev/null
+++ b/lib9p/protogen/c.py
@@ -0,0 +1,200 @@
+# lib9p/protogen/c.py - Generate 9p.generated.c
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import sys
+
+import idl
+
+from . import c9util, c_marshal, c_unmarshal, c_validate, cutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c"]
+
+
+def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
+ cutil.ifdef_init()
+
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#include <stdbool.h>
+#include <stddef.h> /* for size_t */
+#include <inttypes.h> /* for PRI* macros */
+#include <string.h> /* for memset() */
+
+#include <libmisc/assert.h>
+
+#include <lib9p/9p.h>
+
+#include "internal.h"
+"""
+
+ # utilities ################################################################
+ ret += """
+/* utilities ******************************************************************/
+"""
+
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+ id2typ[msg.msgid] = msg
+
+ def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
+ ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
+ for ver in ["unknown", *sorted(versions)]:
+ if ver != "unknown":
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t[{c9util.ver_enum(ver)}] = {{\n"
+ for n in range(*rng):
+ xmsg: idl.Message | None = id2typ.get(n, None)
+ if xmsg:
+ if ver == "unknown": # SPECIAL (initialization)
+ if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]:
+ xmsg = None
+ else:
+ if ver not in xmsg.in_versions:
+ xmsg = None
+ if xmsg:
+ ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n"
+ ret += "\t},\n"
+ ret += cutil.ifdef_pop(0)
+ ret += "};\n"
+ return ret
+
+ for v in sorted(versions):
+ ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
+ ret += (
+ f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n"
+ )
+ ret += "#else\n"
+ ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
+ ret += "#endif\n"
+ ret += "\n"
+ ret += "/**\n"
+ ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n"
+ ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n"
+ ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n"
+ ret += " * several version checks together.\n"
+ ret += " */\n"
+ ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n"
+
+ # strings ##################################################################
+ ret += f"""
+/* strings ********************************************************************/
+
+const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{
+"""
+ for ver in ["unknown", *sorted(versions)]:
+ if ver in versions:
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n'
+ ret += cutil.ifdef_pop(0)
+ ret += "};\n"
+
+ ret += "\n"
+ ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n"
+ ret += msg_table("msg", "name", "char *const", (0, 0x100, 1))
+
+ # bitmasks #################################################################
+ ret += """
+/* bitmasks *******************************************************************/
+"""
+ for typ in typs:
+ if not isinstance(typ, idl.Bitfield):
+ continue
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n"
+ verwidth = max(len(ver) for ver in versions)
+ for ver in sorted(versions):
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver}))
+ ret += (
+ f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
+ + "".join(
+ (
+ "1"
+ if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD)
+ and ver in bit.in_versions
+ else "0"
+ )
+ for bit in reversed(typ.bits)
+ )
+ + ",\n"
+ )
+ ret += cutil.ifdef_pop(1)
+ ret += "};\n"
+ ret += cutil.ifdef_pop(0)
+
+ # validate_* ###############################################################
+ ret += c_validate.gen_c_validate(versions, typs)
+
+ # unmarshal_* ##############################################################
+ ret += c_unmarshal.gen_c_unmarshal(versions, typs)
+
+ # marshal_* ################################################################
+ ret += c_marshal.gen_c_marshal(versions, typs)
+
+ # function tables ##########################################################
+ ret += """
+/* function tables ************************************************************/
+"""
+
+ ret += "\n"
+ ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n"
+ rerror = next(typ for typ in typs if typ.typname == "Rerror")
+ ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
+ for ver in sorted(versions):
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n"
+ ret += cutil.ifdef_pop(0)
+ ret += "};\n"
+
+ ret += "\n"
+ ret += cutil.macro(
+ f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
+ f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n"
+ f"\t\t.validate = validate_##typ,\n"
+ f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
+ f"\t}}\n"
+ )
+ ret += cutil.macro(
+ f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
+ f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n"
+ f"\t}}\n"
+ )
+ ret += "\n"
+ ret += msg_table(
+ "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2)
+ )
+ ret += "\n"
+ ret += msg_table(
+ "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2)
+ )
+ ret += "\n"
+ ret += msg_table(
+ "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2)
+ )
+ ret += "\n"
+ ret += msg_table(
+ "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2)
+ )
+
+ ret += f"""
+LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{
+\treturn validate_stat(ctx);
+}}
+LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{
+\tunmarshal_stat(ctx, out);
+}}
+LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{
+\treturn marshal_stat(ctx, val);
+}}
+"""
+
+ ############################################################################
+ return ret
diff --git a/lib9p/protogen/c9util.py b/lib9p/protogen/c9util.py
new file mode 100644
index 0000000..f9c49fc
--- /dev/null
+++ b/lib9p/protogen/c9util.py
@@ -0,0 +1,117 @@
+# lib9p/protogen/c9util.py - Utilities for generating lib9p-specific C
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import typing
+
+import idl
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+# pylint: disable=unused-variable
+__all__ = [
+ "add_prefix",
+ "ident",
+ "Ident",
+ "IDENT",
+ "ver_enum",
+ "ver_ifdef",
+ "ver_cond",
+ "typename",
+ "idl_expr",
+]
+
+# idents #######################################################################
+
+
+def add_prefix(p: str, s: str) -> str:
+ if s.startswith("_"):
+ return "_" + p + s[1:]
+ return p + s
+
+
+def _ident(p: str, s: str) -> str:
+ return add_prefix(p, s.replace(".", "_"))
+
+
+def ident(s: str) -> str:
+ return _ident("lib9p_", s)
+
+
+def Ident(s: str) -> str:
+ return _ident("lib9p_".upper(), s)
+
+
+def IDENT(s: str) -> str:
+ return _ident("lib9p_", s).upper()
+
+
+# versions #####################################################################
+
+
+def ver_enum(ver: str) -> str:
+ return Ident("VER_" + ver)
+
+
+def ver_ifdef(versions: typing.Collection[str]) -> str:
+ return " || ".join(
+ f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions)
+ )
+
+
+def ver_cond(versions: typing.Collection[str]) -> str:
+ if len(versions) == 1:
+ v = next(v for v in versions)
+ return f"is_ver(ctx, {v.replace('.', '_')})"
+ return "( " + (" || ".join(ver_cond({v}) for v in sorted(versions))) + " )"
+
+
+# misc #########################################################################
+
+
+def typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
+ match typ:
+ case idl.Primitive():
+ if typ.value == 1 and parent and parent.cnt: # SPECIAL (string)
+ return "[[gnu::nonstring]] char"
+ return f"uint{typ.value*8}_t"
+ case idl.Number():
+ return ident(f"{typ.typname}_t")
+ case idl.Bitfield():
+ return ident(f"{typ.typname}_t")
+ case idl.Message():
+ return f"struct {ident(f'msg_{typ.typname}')}"
+ case idl.Struct():
+ return f"struct {ident(typ.typname)}"
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+
+def idl_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
+ ret: list[str] = []
+ for tok in expr.tokens:
+ match tok:
+ case idl.ExprOp():
+ ret.append(tok.op)
+ case idl.ExprLit():
+ ret.append(str(tok.val))
+ case idl.ExprSym(symname="s32_max"):
+ ret.append("INT32_MAX")
+ case idl.ExprSym(symname="s64_max"):
+ ret.append("INT64_MAX")
+ case idl.ExprSym():
+ ret.append(lookup_sym(tok.symname))
+ case _:
+ assert False
+ return " ".join(ret)
+
+
+def arg_used(arg: str) -> str:
+ return arg
+
+
+def arg_unused(arg: str) -> str:
+ return f"LM_UNUSED({arg})"
diff --git a/lib9p/protogen/c_marshal.py b/lib9p/protogen/c_marshal.py
new file mode 100644
index 0000000..152206d
--- /dev/null
+++ b/lib9p/protogen/c_marshal.py
@@ -0,0 +1,357 @@
+# lib9p/protogen/c_marshal.py - Generate C marshal functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import typing
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_marshal"]
+
+
+def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
+ ret = """
+/* marshal_* ******************************************************************/
+
+"""
+ ret += cutil.macro(
+ "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
+ "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
+ "\t\tctx->net_iov_cnt++;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
+ "\tctx->net_iov_cnt++;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_BYTES(ctx, data, len)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
+ "\tctx->net_copied_size += len;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U8LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tctx->net_copied[ctx->net_copied_size] = val;\n"
+ "\tctx->net_copied_size += 1;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U16LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 2;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U32LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 4;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
+ )
+ ret += cutil.macro(
+ "#define MARSHAL_U64LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 8;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
+ )
+
+ class OffsetExpr:
+ static: int
+ cond: dict[frozenset[str], "OffsetExpr"]
+ rep: list[tuple[idlutil.Path, "OffsetExpr"]]
+
+ def __init__(self) -> None:
+ self.static = 0
+ self.rep = []
+ self.cond = {}
+
+ def add(self, other: "OffsetExpr") -> None:
+ self.static += other.static
+ self.rep += other.rep
+ for k, v in other.cond.items():
+ if k in self.cond:
+ self.cond[k].add(v)
+ else:
+ self.cond[k] = v
+
+ def gen_c(
+ self,
+ dsttyp: str,
+ dstvar: str,
+ root: str,
+ indent_depth: int,
+ loop_depth: int,
+ ) -> str:
+ oneline: list[str] = []
+ multiline = ""
+ if self.static:
+ oneline.append(str(self.static))
+ for cnt, sub in self.rep:
+ if not sub.cond and not sub.rep:
+ if sub.static == 1:
+ oneline.append(cnt.c_str(root))
+ else:
+ oneline.append(f"({cnt.c_str(root)})*{sub.static}")
+ continue
+ loopvar = chr(ord("i") + loop_depth)
+ multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
+ multiline += sub.gen_c(
+ "", dstvar, root, indent_depth + 1, loop_depth + 1
+ )
+ multiline += f"{'\t'*indent_depth}}}\n"
+ for vers, sub in self.cond.items():
+ multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
+ multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
+ multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
+ multiline += f"{'\t'*indent_depth}}}\n"
+ multiline += cutil.ifdef_pop(indent_depth)
+ if dsttyp:
+ if not oneline:
+ oneline.append("0")
+ ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
+ elif oneline:
+ ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
+ ret += multiline
+ return ret
+
+ type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]
+
+ def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
+ if not isinstance(typ, idl.Struct):
+ assert typ.static_size
+ ret = OffsetExpr()
+ ret.static = typ.static_size
+ return ret
+
+ stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]]
+
+ def pop_root() -> None:
+ assert False
+
+ def pop_cond() -> None:
+ nonlocal stack
+ key = frozenset(stack[-1][0].elems[-1].in_versions)
+ if key in stack[-2][1].cond:
+ stack[-2][1].cond[key].add(stack[-1][1])
+ else:
+ stack[-2][1].cond[key] = stack[-1][1]
+ stack = stack[:-1]
+
+ def pop_rep() -> None:
+ nonlocal stack
+ member_path = stack[-1][0]
+ member = member_path.elems[-1]
+ assert member.cnt
+ cnt_path = member_path.parent().add(member.cnt)
+ stack[-2][1].rep.append((cnt_path, stack[-1][1]))
+ stack = stack[:-1]
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
+ nonlocal recurse
+
+ ret = recurse(path)
+ if ret != idlutil.WalkCmd.KEEP_GOING:
+ return ret, None
+
+ nonlocal stack
+ stack_len = len(stack)
+
+ def pop() -> None:
+ nonlocal stack
+ nonlocal stack_len
+ while len(stack) > stack_len:
+ stack[-1][2]()
+
+ if path.elems:
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:
+ stack.append((path, OffsetExpr(), pop_cond))
+ if child.cnt:
+ stack.append((path, OffsetExpr(), pop_rep))
+ if not isinstance(child.typ, idl.Struct):
+ assert child.typ.static_size
+ stack[-1][1].static += child.typ.static_size
+ return ret, pop
+
+ stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)]
+ idlutil.walk(typ, handle)
+ return stack[0][1]
+
+ def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
+ return idlutil.WalkCmd.KEEP_GOING
+
+ def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
+ def ret(path: idlutil.Path) -> idlutil.WalkCmd:
+ if len(path.elems) == 1 and path.elems[0].membname == name:
+ return idlutil.WalkCmd.ABORT
+ return idlutil.WalkCmd.KEEP_GOING
+
+ return ret
+
+ for typ in typs:
+ if not (
+ isinstance(typ, idl.Message) or typ.typname == "stat"
+ ): # SPECIAL (include stat)
+ continue
+ assert isinstance(typ, idl.Struct)
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n"
+
+ # Pass 1 - check size
+ max_size = max(typ.max_size(v) for v in typ.in_versions)
+
+ if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e)
+ ret += get_offset_expr(typ, go_to_end).gen_c(
+ "uint64_t", "needed_size", "val->", 1, 0
+ )
+ ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
+ else:
+ ret += get_offset_expr(typ, go_to_end).gen_c(
+ "uint32_t", "needed_size", "val->", 1, 0
+ )
+ ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
+ if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
+ ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
+ ret += f'\t\t\t"{typ.typname}",\n'
+ ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
+ ret += "\t\t\tctx->ctx->max_msg_size);\n"
+ ret += "\t\treturn true;\n"
+ ret += "\t}\n"
+
+ # Pass 2 - write data
+ ifdef_depth = 1
+ stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)]
+
+ def handle(
+ path: idlutil.Path,
+ ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
+ nonlocal ret
+ nonlocal ifdef_depth
+ nonlocal stack
+ stack_len = len(stack)
+
+ def pop() -> None:
+ nonlocal ret
+ nonlocal ifdef_depth
+ nonlocal stack
+ nonlocal stack_len
+ while len(stack) > stack_len:
+ ret += f"{'\t'*(len(stack)-1)}}}\n"
+ if stack[-1][1]:
+ ifdef_depth -= 1
+ ret += cutil.ifdef_pop(ifdef_depth)
+ stack = stack[:-1]
+
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)
+ struct = path.elems[-1].typ if path.elems else path.root
+ if isinstance(struct, idl.Struct):
+ offsets: list[str] = []
+ for member in struct.members:
+ if not member.val:
+ continue
+ for tok in member.val.tokens:
+ if not isinstance(tok, idl.ExprSym):
+ continue
+ if tok.symname == "end" or tok.symname.startswith("&"):
+ if tok.symname not in offsets:
+ offsets.append(tok.symname)
+ for name in offsets:
+ name_prefix = "offsetof_" + "".join(
+ m.membname + "_" for m in path.elems
+ )
+ if name == "end":
+ if not path.elems:
+ nonlocal max_size
+ if max_size > cutil.UINT32_MAX:
+ ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
+ else:
+ ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n"
+ continue
+ recurse: OffsetExprRecursion = go_to_end
+ else:
+ assert name.startswith("&")
+ name = name[1:]
+ recurse = go_to_tok(name)
+ expr = get_offset_expr(struct, recurse)
+ expr_prefix = path.c_str("val->", loopdepth)
+ if not expr_prefix.endswith(">"):
+ expr_prefix += "."
+ ret += expr.gen_c(
+ "uint32_t",
+ name_prefix + name,
+ expr_prefix,
+ len(stack),
+ loopdepth,
+ )
+ if path.elems:
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:
+ ret += cutil.ifdef_push(
+ ifdef_depth + 1, c9util.ver_ifdef(child.in_versions)
+ )
+ ifdef_depth += 1
+ ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n"
+ stack.append((path, True))
+ if child.cnt:
+ cnt_path = path.parent().add(child.cnt)
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ if path.root.typname == "stat": # SPECIAL (stat)
+ ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ else:
+ ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+ loopvar = chr(ord("i") + loopdepth - 1)
+ ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
+ stack.append((path, False))
+ if not isinstance(child.typ, idl.Struct):
+ if child.val:
+
+ def lookup_sym(sym: str) -> str:
+ nonlocal path
+ if sym.startswith("&"):
+ sym = sym[1:]
+ return (
+ "offsetof_"
+ + "".join(m.membname + "_" for m in path.elems[:-1])
+ + sym
+ )
+
+ val = c9util.idl_expr(child.val, lookup_sym)
+ else:
+ val = path.c_str("val->")
+ if isinstance(child.typ, idl.Bitfield):
+ val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
+ ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
+ return idlutil.WalkCmd.KEEP_GOING, pop
+
+ idlutil.walk(typ, handle)
+ del handle
+ del stack
+ del max_size
+
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ ret += cutil.ifdef_pop(0)
+ return ret
diff --git a/lib9p/protogen/c_unmarshal.py b/lib9p/protogen/c_unmarshal.py
new file mode 100644
index 0000000..e17f456
--- /dev/null
+++ b/lib9p/protogen/c_unmarshal.py
@@ -0,0 +1,92 @@
+# lib9p/protogen/c_unmarshal.py - Generate C unmarshal functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_unmarshal"]
+
+
+def gen_c_unmarshal(versions: set[str], typs: list[idl.UserType]) -> str:
+ ret = """
+/* unmarshal_* ****************************************************************/
+
+LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
+\t*out = ctx->net_bytes[ctx->net_offset];
+\tctx->net_offset += 1;
+}
+
+LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
+\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 2;
+}
+
+LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
+\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 4;
+}
+
+LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
+\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 8;
+}
+"""
+ for typ in idlutil.topo_sorted(typs):
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = (
+ c9util.arg_unused
+ if (isinstance(typ, idl.Struct) and not typ.members)
+ else c9util.arg_used
+ )
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n"
+ match typ:
+ case idl.Number():
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
+ case idl.Bitfield():
+ ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n"
+ case idl.Struct():
+ ret += "\tmemset(out, 0, sizeof(*out));\n"
+
+ for member in typ.members:
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ if member.val:
+ ret += f"\tctx->net_offset += {member.static_size};\n"
+ continue
+ ret += "\t"
+
+ prefix = "\t"
+ if member.in_versions != typ.in_versions:
+ ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) "
+ prefix = "\t\t"
+ if member.cnt:
+ if member.in_versions != typ.in_versions:
+ ret += "{\n"
+ ret += prefix
+ if member.typ.static_size == 1: # SPECIAL (string, zerocopy)
+ ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
+ ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n"
+ else:
+ ret += f"out->{member.membname} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n"
+ ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n"
+ if member.in_versions != typ.in_versions:
+ ret += "\t}\n"
+ else:
+ ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n"
+ ret += cutil.ifdef_pop(1)
+ ret += "}\n"
+ ret += cutil.ifdef_pop(0)
+ return ret
diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py
new file mode 100644
index 0000000..a3f4348
--- /dev/null
+++ b/lib9p/protogen/c_validate.py
@@ -0,0 +1,171 @@
+# lib9p/protogen/c_validate.py - Generate C validation functions
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+
+# pylint: disable=unused-variable
+__all__ = ["gen_c_validate"]
+
+
+def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
+ return bool(member.max or member.val or any(m.cnt == member for m in typ.members))
+
+
+def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
+ ret = """
+/* validate_* *****************************************************************/
+
+LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
+\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
+\t\t/* If needed-net-size overflowed uint32_t, then
+\t\t * there's no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\tif (ctx->net_offset > ctx->net_size)
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
+\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
+\t\t/* If needed-host-size overflowed size_t, then there's
+\t\t * no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
+ size_t cnt,
+ _validate_fn_t item_fn, size_t item_host_size) {
+\tfor (size_t i = 0; i < cnt; i++)
+\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
+\t\t\treturn true;
+\treturn false;
+}
+
+LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
+LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
+LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
+LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
+"""
+
+ for typ in idlutil.topo_sorted(typs):
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = (
+ c9util.arg_unused
+ if (isinstance(typ, idl.Struct) and not typ.members)
+ else c9util.arg_used
+ )
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+ ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
+
+ match typ:
+ case idl.Number():
+ ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
+ case idl.Bitfield():
+ ret += f"\t if (validate_{typ.static_size}(ctx))\n"
+ ret += "\t\treturn true;\n"
+ ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
+ if typ.static_size == 1:
+ ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
+ else:
+ ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += "\tif (val & ~mask)\n"
+ ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
+ ret += "\treturn false;\n"
+ case idl.Struct(): # and idl.Message()
+ if len(typ.members) == 0:
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ # Pass 1 - declare value variables
+ for member in typ.members:
+ if should_save_value(typ, member):
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
+ ret += cutil.ifdef_pop(1)
+
+ # Pass 2 - declare offset variables
+ mark_offset: set[str] = set()
+ for member in typ.members:
+ for tok in [*member.max.tokens, *member.val.tokens]:
+ if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
+ if tok.symname[1:] not in mark_offset:
+ ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
+ mark_offset.add(tok.symname[1:])
+
+ # Pass 3 - main pass
+ ret += "\treturn false\n"
+ for member in typ.members:
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += "\t || "
+ if member.in_versions != typ.in_versions:
+ ret += "( " + c9util.ver_cond(member.in_versions) + " && "
+ if member.cnt is not None:
+ if member.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret += f"_validate_size_net(ctx, {member.cnt.membname})"
+ else:
+ ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
+ if typ.typname == "s": # SPECIAL (string)
+ ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
+ else:
+ if should_save_value(typ, member):
+ ret += "("
+ if member.membname in mark_offset:
+ ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
+ ret += f"validate_{member.typ.typname}(ctx)"
+ if member.membname in mark_offset:
+ ret += "; })"
+ if should_save_value(typ, member):
+ nbytes = member.static_size
+ assert nbytes
+ if nbytes == 1:
+ ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
+ else:
+ ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
+ if member.in_versions != typ.in_versions:
+ ret += " )"
+ ret += "\n"
+
+ # Pass 4 - validate ,max= and ,val= constraints
+ for member in typ.members:
+
+ def lookup_sym(sym: str) -> str:
+ match sym:
+ case "end":
+ return "ctx->net_offset"
+ case _:
+ assert sym.startswith("&")
+ return f"_{sym[1:]}_offset"
+
+ if member.max:
+ assert member.static_size
+ nbits = member.static_size * 8
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
+ if member.val:
+ assert member.static_size
+ nbits = member.static_size * 8
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
+
+ ret += cutil.ifdef_pop(1)
+ ret += "\t ;\n"
+ ret += "}\n"
+ ret += cutil.ifdef_pop(0)
+ return ret
diff --git a/lib9p/protogen/cutil.py b/lib9p/protogen/cutil.py
new file mode 100644
index 0000000..a78cd17
--- /dev/null
+++ b/lib9p/protogen/cutil.py
@@ -0,0 +1,84 @@
+# lib9p/protogen/cutil.py - Utilities for generating C code
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+# pylint: disable=unused-variable
+__all__ = [
+ "UINT32_MAX",
+ "UINT64_MAX",
+ "macro",
+ "ifdef_init",
+ "ifdef_push",
+ "ifdef_pop",
+ "ifdef_leaf_is_noop",
+]
+
+UINT32_MAX = (1 << 32) - 1
+UINT64_MAX = (1 << 64) - 1
+
+
+def tab_ljust(s: str, width: int) -> str:
+ cur = len(s.expandtabs(tabsize=8))
+ if cur >= width:
+ return s
+ return s + " " * (width - cur)
+
+
+def macro(full: str) -> str:
+ full = full.rstrip()
+ assert "\n" in full
+ lines = [l.rstrip() for l in full.split("\n")]
+ width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1])
+ lines = [tab_ljust(l, width) for l in lines]
+ return " \\\n".join(lines).rstrip() + "\n"
+
+
+_ifdef_stack: list[str | None] = []
+
+
+def ifdef_init() -> None:
+ global _ifdef_stack
+ _ifdef_stack = []
+
+
+def ifdef_push(n: int, _newval: str) -> str:
+ # Grow the stack as needed
+ while len(_ifdef_stack) < n:
+ _ifdef_stack.append(None)
+
+ # Set some variables
+ parentval: str | None = None
+ for x in _ifdef_stack[:-1]:
+ if x is not None:
+ parentval = x
+ oldval = _ifdef_stack[-1]
+ newval: str | None = _newval
+ if newval == parentval:
+ newval = None
+
+ # Put newval on the stack.
+ _ifdef_stack[-1] = newval
+
+ # Build output.
+ ret = ""
+ if newval != oldval:
+ if oldval is not None:
+ ret += f"#endif /* {oldval} */\n"
+ if newval is not None:
+ ret += f"#if {newval}\n"
+ return ret
+
+
+def ifdef_pop(n: int) -> str:
+ global _ifdef_stack
+ ret = ""
+ while len(_ifdef_stack) > n:
+ if _ifdef_stack[-1] is not None:
+ ret += f"#endif /* {_ifdef_stack[-1]} */\n"
+ _ifdef_stack = _ifdef_stack[:-1]
+ return ret
+
+
+def ifdef_leaf_is_noop() -> bool:
+ return not _ifdef_stack[-1]
diff --git a/lib9p/protogen/h.py b/lib9p/protogen/h.py
new file mode 100644
index 0000000..7785ca1
--- /dev/null
+++ b/lib9p/protogen/h.py
@@ -0,0 +1,447 @@
+# lib9p/protogen/h.py - Generate 9p.generated.h
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import sys
+import typing
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+# pylint: disable=unused-variable
+__all__ = ["gen_h"]
+
+# get_buffer_size() ############################################################
+
+
+class BufferSize(typing.NamedTuple):
+ min_size: int # really just here to sanity-check against typ.min_size(version)
+ exp_size: int # "expected" or max-reasonable size
+ max_size: int # really just here to sanity-check against typ.max_size(version)
+ max_copy: int
+ max_copy_extra: str
+ max_iov: int
+ max_iov_extra: str
+
+
+class TmpBufferSize:
+ min_size: int
+ exp_size: int
+ max_size: int
+ max_copy: int
+ max_copy_extra: str
+ max_iov: int
+ max_iov_extra: str
+
+ tmp_starts_with_copy: bool
+ tmp_ends_with_copy: bool
+
+ def __init__(self) -> None:
+ self.min_size = 0
+ self.exp_size = 0
+ self.max_size = 0
+ self.max_copy = 0
+ self.max_copy_extra = ""
+ self.max_iov = 0
+ self.max_iov_extra = ""
+ self.tmp_starts_with_copy = False
+ self.tmp_ends_with_copy = False
+
+
+def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
+ assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)
+
+ ret = TmpBufferSize()
+
+ if not isinstance(typ, idl.Struct):
+ assert typ.static_size
+ ret.min_size = typ.static_size
+ ret.exp_size = typ.static_size
+ ret.max_size = typ.static_size
+ ret.max_copy = typ.static_size
+ ret.max_iov = 1
+ ret.tmp_starts_with_copy = True
+ ret.tmp_ends_with_copy = True
+ return ret
+
+ def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]:
+ nonlocal ret
+ if path.elems:
+ child = path.elems[-1]
+ if version not in child.in_versions:
+ return idlutil.WalkCmd.DONT_RECURSE, None
+ if child.cnt:
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret.max_iov += 1
+ # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
+ ret.exp_size += 27 if child.membname == "utf8" else 8192
+ ret.max_size += child.max_cnt
+ ret.tmp_ends_with_copy = False
+ return idlutil.WalkCmd.DONT_RECURSE, None
+ sub = _get_buffer_size(child.typ, version)
+ ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM
+ ret.max_size += sub.max_size * child.max_cnt
+ if child.membname == "wname" and path.root.typname in (
+ "Tsread",
+ "Tswrite",
+ ): # SPECIAL (9P2000.e)
+ assert ret.tmp_ends_with_copy
+ assert sub.tmp_starts_with_copy
+ assert not sub.tmp_ends_with_copy
+ ret.max_copy_extra = (
+ f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
+ )
+ ret.max_iov_extra = (
+ f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
+ )
+ ret.max_iov -= 1
+ else:
+ ret.max_copy += sub.max_copy * child.max_cnt
+ if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy
+ ret.max_iov += 1
+ else: # contains zero-copy segments
+ ret.max_iov += sub.max_iov * child.max_cnt
+ if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
+ # we can merge this one
+ ret.max_iov -= 1
+ if (
+ sub.tmp_ends_with_copy
+ and sub.tmp_starts_with_copy
+ and sub.max_iov > 1
+ ):
+ # we can merge these
+ ret.max_iov -= child.max_cnt - 1
+ ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
+ return idlutil.WalkCmd.DONT_RECURSE, None
+ if not isinstance(child.typ, idl.Struct):
+ assert child.typ.static_size
+ if not ret.tmp_ends_with_copy:
+ if ret.max_size == 0:
+ ret.tmp_starts_with_copy = True
+ ret.max_iov += 1
+ ret.tmp_ends_with_copy = True
+ ret.min_size += child.typ.static_size
+ ret.exp_size += child.typ.static_size
+ ret.max_size += child.typ.static_size
+ ret.max_copy += child.typ.static_size
+ return idlutil.WalkCmd.KEEP_GOING, None
+
+ idlutil.walk(typ, handle)
+ assert ret.min_size == typ.min_size(version)
+ assert ret.max_size == typ.max_size(version)
+ return ret
+
+
+def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
+ tmp = _get_buffer_size(typ, version)
+ return BufferSize(
+ min_size=tmp.min_size,
+ exp_size=tmp.exp_size,
+ max_size=tmp.max_size,
+ max_copy=tmp.max_copy,
+ max_copy_extra=tmp.max_copy_extra,
+ max_iov=tmp.max_iov,
+ max_iov_extra=tmp.max_iov_extra,
+ )
+
+
+# Generate .h ##################################################################
+
+
+def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
+ cutil.ifdef_init()
+
+ ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
+
+#ifndef _LIB9P_9P_H_
+\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
+#endif
+
+#include <stdint.h> /* for uint{{n}}_t types */
+
+#include <libhw/generic/net.h> /* for struct iovec */
+"""
+
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+ id2typ[msg.msgid] = msg
+
+ ret += """
+/* config *********************************************************************/
+
+#include "config.h"
+"""
+ for ver in sorted(versions):
+ ret += "\n"
+ ret += f"#ifndef {c9util.ver_ifdef({ver})}\n"
+ ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n"
+ if ver == "9P2000.e": # SPECIAL (9P2000.e)
+ ret += "#else\n"
+ ret += f"\t#if {c9util.ver_ifdef({ver})}\n"
+ ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
+ ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
+ ret += "\t\t#endif\n"
+ ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
+ ret += "\t#endif\n"
+ ret += "#endif\n"
+
+ ret += f"""
+/* enum version ***************************************************************/
+
+enum {c9util.ident('version')} {{
+"""
+ fullversions = ["unknown = 0", *sorted(versions)]
+ verwidth = max(len(v) for v in fullversions)
+ for ver in fullversions:
+ if ver in versions:
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+ ret += f"\t{c9util.ver_enum(ver)},"
+ ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
+ ret += cutil.ifdef_pop(0)
+ ret += f"\t{c9util.ver_enum('NUM')},\n"
+ ret += "};\n"
+
+ ret += """
+/* enum msg_type **************************************************************/
+
+"""
+ ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n"
+ namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
+ for n in range(0x100):
+ if n not in id2typ:
+ continue
+ msg = id2typ[n]
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions))
+ ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n"
+ ret += cutil.ifdef_pop(0)
+ ret += "};\n"
+
+ ret += """
+/* payload types **************************************************************/
+"""
+
+ def per_version_comment(
+ typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str]
+ ) -> str:
+ lines: dict[str, str] = {}
+ for version in sorted(typ.in_versions):
+ lines[version] = fn(typ, version)
+ if len(set(lines.values())) == 1:
+ for _, line in lines.items():
+ return f"/* {line} */\n"
+ assert False
+ else:
+ ret = ""
+ v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions)
+ for version, line in lines.items():
+ ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n"
+ return ret
+
+ for typ in idlutil.topo_sorted(typs):
+ ret += "\n"
+ ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+
+ def sum_size(typ: idl.UserType, version: str) -> str:
+ sz = get_buffer_size(typ, version)
+ assert (
+ sz.min_size <= sz.exp_size
+ and sz.exp_size <= sz.max_size
+ and sz.max_size < cutil.UINT64_MAX
+ )
+ ret = ""
+ if sz.min_size == sz.max_size:
+ ret += f"size = {sz.min_size:,}"
+ else:
+ ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}"
+ if sz.max_size > cutil.UINT32_MAX:
+ ret += " (warning: >UINT32_MAX)"
+ ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}"
+ return ret
+
+ ret += per_version_comment(typ, sum_size)
+
+ match typ:
+ case idl.Number():
+ ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
+ prefix = f"{c9util.IDENT(typ.typname)}_"
+ namewidth = max(len(name) for name in typ.vals)
+ for name, val in typ.vals.items():
+ ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
+ case idl.Bitfield():
+ ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
+
+ def bitname(val: idl.Bit | idl.BitAlias) -> str:
+ s = val.bitname
+ match val:
+ case idl.Bit(cat=idl.BitCat.RESERVED):
+ s = "_RESERVED_" + s
+ case idl.Bit(cat=idl.BitCat.SUBFIELD):
+ assert isinstance(typ, idl.Bitfield)
+ n = sum(
+ 1
+ for b in typ.bits[: val.num]
+ if b.cat == idl.BitCat.SUBFIELD
+ and b.bitname == val.bitname
+ )
+ s = f"_{s}_{n}"
+ case idl.Bit(cat=idl.BitCat.UNUSED):
+ return ""
+ return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s))
+
+ namewidth = max(
+ len(bitname(val)) for val in [*typ.bits, *typ.names.values()]
+ )
+
+ ret += "\n"
+ for bit in reversed(typ.bits):
+ vers = bit.in_versions
+ if bit.cat == idl.BitCat.UNUSED:
+ vers = typ.in_versions
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers))
+
+ # It is important all of the `beg` strings have
+ # the same length.
+ end = ""
+ match bit.cat:
+ case (
+ idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD
+ ):
+ if cutil.ifdef_leaf_is_noop():
+ beg = "#define "
+ else:
+ beg = "# define"
+ case idl.BitCat.UNUSED:
+ beg = "/* unused"
+ end = " */"
+
+ c_name = bitname(bit)
+ c_val = f"1<<{bit.num}"
+ ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n"
+ if aliases := [
+ alias
+ for alias in typ.names.values()
+ if isinstance(alias, idl.BitAlias)
+ ]:
+ ret += "\n"
+
+ for alias in aliases:
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions))
+
+ end = ""
+ if cutil.ifdef_leaf_is_noop():
+ beg = "#define "
+ else:
+ beg = "# define"
+
+ c_name = bitname(alias)
+ c_val = alias.val
+ ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n"
+ ret += cutil.ifdef_pop(1)
+ del bitname
+ case idl.Struct(): # and idl.Message():
+ ret += c9util.typename(typ) + " {"
+ if not typ.members:
+ ret += "};\n"
+ continue
+ ret += "\n"
+
+ typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members)
+
+ for member in typ.members:
+ if member.val:
+ continue
+ ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+ ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n"
+ ret += cutil.ifdef_pop(1)
+ ret += "};\n"
+ del typ
+ ret += cutil.ifdef_pop(0)
+
+ ret += """
+/* containers *****************************************************************/
+"""
+ ret += "\n"
+ ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n"
+
+ tmsg_max_iov: dict[str, int] = {}
+ tmsg_max_copy: dict[str, int] = {}
+ rmsg_max_iov: dict[str, int] = {}
+ rmsg_max_copy: dict[str, int] = {}
+ for typ in typs:
+ if not isinstance(typ, idl.Message):
+ continue
+ if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e)
+ continue
+ max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
+ max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
+ for version in typ.in_versions:
+ if version not in max_iov:
+ max_iov[version] = 0
+ max_copy[version] = 0
+ sz = get_buffer_size(typ, version)
+ if sz.max_iov > max_iov[version]:
+ max_iov[version] = sz.max_iov
+ if sz.max_copy > max_copy[version]:
+ max_copy[version] = sz.max_copy
+
+ for name, table in [
+ ("tmsg_max_iov", tmsg_max_iov),
+ ("tmsg_max_copy", tmsg_max_copy),
+ ("rmsg_max_iov", rmsg_max_iov),
+ ("rmsg_max_copy", rmsg_max_copy),
+ ]:
+ inv: dict[int, set[str]] = {}
+ for version, maxval in table.items():
+ if maxval not in inv:
+ inv[maxval] = set()
+ inv[maxval].add(version)
+
+ ret += "\n"
+ directive = "if"
+ seen_e = False # SPECIAL (9P2000.e)
+ for maxval in sorted(inv, reverse=True):
+ ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n"
+ indent = 1
+ if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
+ typ = next(typ for typ in typs if typ.typname == "Tswrite")
+ sz = get_buffer_size(typ, "9P2000.e")
+ match name:
+ case "tmsg_max_iov":
+ maxexpr = f"{sz.max_iov}{sz.max_iov_extra}"
+ case "tmsg_max_copy":
+ maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
+ case _:
+ assert False
+ ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n"
+ ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n"
+ ret += "\t#else\n"
+ indent += 1
+ ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n"
+ if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
+ ret += "\t#endif\n"
+ if "9P2000.e" in inv[maxval]:
+ seen_e = True
+ directive = "elif"
+ ret += "#endif\n"
+
+ ret += "\n"
+ ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n"
+ ret += "\tsize_t iov_cnt;\n"
+ ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n"
+ ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n"
+ ret += "};\n"
+
+ ret += "\n"
+ ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n"
+ ret += "\tsize_t iov_cnt;\n"
+ ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n"
+ ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n"
+ ret += "};\n"
+
+ return ret
diff --git a/lib9p/protogen/idlutil.py b/lib9p/protogen/idlutil.py
new file mode 100644
index 0000000..dc4d012
--- /dev/null
+++ b/lib9p/protogen/idlutil.py
@@ -0,0 +1,112 @@
+# lib9p/protogen/idlutil.py - Utilities for working with the 9P idl package
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import enum
+import graphlib
+import typing
+
+import idl
+
+# pylint: disable=unused-variable
+__all__ = [
+ "topo_sorted",
+ "Path",
+ "WalkCmd",
+ "WalkHandler",
+ "walk",
+]
+
+# topo_sorted() ################################################################
+
+
+def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
+ ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
+ for typ in typs:
+ match typ:
+ case idl.Number():
+ ts.add(typ)
+ case idl.Bitfield():
+ ts.add(typ)
+ case idl.Struct(): # and idl.Message():
+ deps = [
+ member.typ
+ for member in typ.members
+ if not isinstance(member.typ, idl.Primitive)
+ ]
+ ts.add(typ, *deps)
+ return ts.static_order()
+
+
+# walk() #######################################################################
+
+
+class Path:
+ root: idl.Type
+ elems: list[idl.StructMember]
+
+ def __init__(
+ self, root: idl.Type, elems: list[idl.StructMember] | None = None
+ ) -> None:
+ self.root = root
+ self.elems = elems if elems is not None else []
+
+ def add(self, elem: idl.StructMember) -> "Path":
+ return Path(self.root, self.elems + [elem])
+
+ def parent(self) -> "Path":
+ return Path(self.root, self.elems[:-1])
+
+ def c_str(self, base: str, loopdepth: int = 0) -> str:
+ ret = base
+ for i, elem in enumerate(self.elems):
+ if i > 0:
+ ret += "."
+ ret += elem.membname
+ if elem.cnt:
+ ret += f"[{chr(ord('i')+loopdepth)}]"
+ loopdepth += 1
+ return ret
+
+ def __str__(self) -> str:
+ return self.c_str(self.root.typname + "->")
+
+
+class WalkCmd(enum.Enum):
+ KEEP_GOING = 1
+ DONT_RECURSE = 2
+ ABORT = 3
+
+
+type WalkHandler = typing.Callable[
+ [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
+]
+
+
+def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
+ typ = path.elems[-1].typ if path.elems else path.root
+
+ ret, atexit = handle(path)
+
+ if isinstance(typ, idl.Struct):
+ match ret:
+ case WalkCmd.KEEP_GOING:
+ for member in typ.members:
+ if _walk(path.add(member), handle) == WalkCmd.ABORT:
+ ret = WalkCmd.ABORT
+ break
+ case WalkCmd.DONT_RECURSE:
+ ret = WalkCmd.KEEP_GOING
+ case WalkCmd.ABORT:
+ ret = WalkCmd.ABORT
+ case _:
+ assert False, f"invalid cmd: {ret}"
+
+ if atexit:
+ atexit()
+ return ret
+
+
+def walk(typ: idl.Type, handle: WalkHandler) -> None:
+ _walk(Path(typ), handle)