summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen1853
1 files changed, 1025 insertions, 828 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index ec42cfd..f2b4f13 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -2,493 +2,108 @@
# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
# defining 9P protocol variants.
#
-# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
import enum
+import graphlib
import os.path
-import re
-from abc import ABC, abstractmethod
-from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast
+import sys
+import typing
+
+sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
+
+import idl
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".
-# Types ########################################################################
-
-
-class Primitive(enum.Enum):
- u8 = 1
- u16 = 2
- u32 = 4
- u64 = 8
-
- @property
- def in_versions(self) -> set[str]:
- return set()
-
- @property
- def name(self) -> str:
- return str(self.value)
-
- @property
- def static_size(self) -> int:
- return self.value
-
-
-class Number:
- name: str
- in_versions: set[str]
-
- prim: Primitive
-
- def __init__(self) -> None:
- self.in_versions = set()
-
- @property
- def static_size(self) -> int:
- return self.prim.static_size
-
-
-class BitfieldVal:
- name: str
- in_versions: set[str]
-
- val: str
-
- def __init__(self) -> None:
- self.in_versions = set()
-
-
-class Bitfield:
- name: str
- in_versions: set[str]
-
- prim: Primitive
-
- bits: list[str] # bitnames
- names: dict[str, BitfieldVal] # bits *and* aliases
-
- def __init__(self) -> None:
- self.in_versions = set()
- self.names = {}
-
- @property
- def static_size(self) -> int:
- return self.prim.static_size
-
- def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
- """Return whether the given bit is valid in the given protocol
- version.
-
- """
- bitname = self.bits[bit] if isinstance(bit, int) else bit
- assert bitname in self.bits
- if not bitname:
- return False
- if bitname.startswith("_"):
- return False
- if ver and (ver not in self.names[bitname].in_versions):
- return False
- return True
-
-
-class ExprLit:
- val: int
-
- def __init__(self, val: int) -> None:
- self.val = val
-
-
-class ExprSym:
- name: str
-
- def __init__(self, name: str) -> None:
- self.name = name
-
-
-class ExprOp:
- op: Literal["-", "+"]
-
- def __init__(self, op: Literal["-", "+"]) -> None:
- self.op = op
-
-class Expr:
- tokens: list[ExprLit | ExprSym | ExprOp]
+# Utilities ####################################################################
- def __init__(self) -> None:
- self.tokens = []
-
- def __bool__(self) -> bool:
- return len(self.tokens) > 0
-
-
-class StructMember:
- # from left-to-right when parsing
- cnt: str | None = None
- name: str
- typ: "Type"
- max: Expr
- val: Expr
-
- in_versions: set[str]
-
- @property
- def static_size(self) -> int | None:
- if self.cnt:
- return None
- return self.typ.static_size
-
-
-class Struct:
- name: str
- in_versions: set[str]
-
- members: list[StructMember]
-
- def __init__(self) -> None:
- self.in_versions = set()
-
- @property
- def static_size(self) -> int | None:
- size = 0
- for member in self.members:
- msize = member.static_size
- if msize is None:
- return None
- size += msize
- return size
-
-
-class Message(Struct):
- @property
- def msgid(self) -> int:
- assert len(self.members) >= 3
- assert self.members[1].name == "typ"
- assert self.members[1].static_size == 1
- assert self.members[1].val
- assert len(self.members[1].val.tokens) == 1
- assert isinstance(self.members[1].val.tokens[0], ExprLit)
- return self.members[1].val.tokens[0].val
-
-
-Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message
-# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
-T = TypeVar("T", Number, Bitfield, Struct, Message)
-
-# Parse *.9p ###################################################################
-
-re_priname = "(?:1|2|4|8)" # primitive names
-re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names
-re_impname = r"(?:\*|" + re_symname + ")" # names we can import
-re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be
-
-re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be
-
-re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)"
-
-re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})"
-re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"
-
-re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"
-
-
-def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
- spec = spec.strip()
-
- bit: int | None
- val: BitfieldVal
- if m := re.fullmatch(re_bitspec_bit, spec):
- bit = int(m.group("bit"))
- name = m.group("name")
-
- val = BitfieldVal()
- val.name = name
- val.val = f"1<<{bit}"
- val.in_versions.add(ver)
-
- if bit < 0 or bit >= len(bf.bits):
- raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
- if bf.bits[bit]:
- raise ValueError(f"{bf.name}: bit {bit} already assigned")
- bf.bits[bit] = val.name
- elif m := re.fullmatch(re_bitspec_alias, spec):
- name = m.group("name")
- valstr = m.group("val")
-
- val = BitfieldVal()
- val.name = name
- val.val = valstr
- val.in_versions.add(ver)
- else:
- raise SyntaxError(f"invalid bitfield spec {repr(spec)}")
-
- if val.name in bf.names:
- raise ValueError(f"{bf.name}: name {val.name} already assigned")
- bf.names[val.name] = val
-
-
-def parse_expr(expr: str) -> Expr:
- assert re.fullmatch(re_expr, expr)
- ret = Expr()
- for tok in re.split("([-+])", expr):
- if tok == "-" or tok == "+":
- # I, for the life of me, do not understand why I need this
- # cast() to keep mypy happy.
- ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))]
- elif re.fullmatch("[0-9]+", tok):
- ret.tokens += [ExprLit(int(tok))]
- else:
- ret.tokens += [ExprSym(tok)]
- return ret
-
-
-def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None:
- for spec in specs.split():
- m = re.fullmatch(re_memberspec, spec)
- if not m:
- raise SyntaxError(f"invalid member spec {repr(spec)}")
-
- member = StructMember()
- member.in_versions = {ver}
-
- member.name = m.group("name")
- if any(x.name == member.name for x in struct.members):
- raise ValueError(f"duplicate member name {repr(member.name)}")
-
- if m.group("typ") not in env:
- raise NameError(f"Unknown type {repr(m.group(2))}")
- member.typ = env[m.group("typ")]
-
- if cnt := m.group("cnt"):
- if len(struct.members) == 0 or struct.members[-1].name != cnt:
- raise ValueError(f"list count must be previous item: {repr(cnt)}")
- if not isinstance(struct.members[-1].typ, Primitive):
- raise ValueError(f"list count must be an integer type: {repr(cnt)}")
- member.cnt = cnt
+idprefix = "lib9p_"
- if maxstr := m.group("max"):
- if (not isinstance(member.typ, Primitive)) or member.cnt:
- raise ValueError("',max=' may only be specified on a non-repeated atom")
- member.max = parse_expr(maxstr)
- else:
- member.max = Expr()
+u32max = (1 << 32) - 1
+u64max = (1 << 64) - 1
- if valstr := m.group("val"):
- if (not isinstance(member.typ, Primitive)) or member.cnt:
- raise ValueError("',val=' may only be specified on a non-repeated atom")
- member.val = parse_expr(valstr)
- else:
- member.val = Expr()
-
- struct.members += [member]
-
-
-def re_string(grpname: str) -> str:
- return f'"(?P<{grpname}>[^"]*)"'
-
-
-re_line_version = f"version\\s+{re_string('version')}"
-re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)"
-re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
-re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
-re_line_bitfield_ = (
- f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}"
-)
-re_line_struct = (
- f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
-)
-re_line_msg = (
- f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
-)
-re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg
-
-
-def parse_file(
- filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
-) -> tuple[str, list[Type]]:
- version: str | None = None
- env: dict[str, Type] = {
- "1": Primitive.u8,
- "2": Primitive.u16,
- "4": Primitive.u32,
- "8": Primitive.u64,
- }
-
- def get_type(name: str, tc: type[T]) -> T:
- nonlocal env
- if name not in env:
- raise NameError(f"Unknown type {repr(name)}")
- ret = env[name]
- if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
- raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
- return ret
- with open(filename, "r") as fh:
- prev: Type | None = None
- for line in fh:
- line = line.split("#", 1)[0].rstrip()
- if not line:
- continue
- if m := re.fullmatch(re_line_version, line):
- if version:
- raise SyntaxError("must have exactly 1 version line")
- version = m.group("version")
- continue
- if not version:
- raise SyntaxError("must have exactly 1 version line")
-
- if m := re.fullmatch(re_line_import, line):
- other_version, other_typs = get_include(m.group("file"))
- for symname in m.group("syms").split(sep=","):
- symname = symname.strip()
- for typ in other_typs:
- if typ.name == symname or symname == "*":
- match typ:
- case Primitive():
- pass
- case Number():
- typ.in_versions.add(version)
- case Bitfield():
- typ.in_versions.add(version)
- for val in typ.names.values():
- if other_version in val.in_versions:
- val.in_versions.add(version)
- case Struct(): # and Message()
- typ.in_versions.add(version)
- for member in typ.members:
- if other_version in member.in_versions:
- member.in_versions.add(version)
- env[typ.name] = typ
- elif m := re.fullmatch(re_line_num, line):
- num = Number()
- num.name = m.group("name")
- num.in_versions.add(version)
-
- prim = env[m.group("prim")]
- assert isinstance(prim, Primitive)
- num.prim = prim
-
- env[num.name] = num
- prev = num
- elif m := re.fullmatch(re_line_bitfield, line):
- bf = Bitfield()
- bf.name = m.group("name")
- bf.in_versions.add(version)
-
- prim = env[m.group("prim")]
- assert isinstance(prim, Primitive)
- bf.prim = prim
-
- bf.bits = (prim.static_size * 8) * [""]
-
- env[bf.name] = bf
- prev = bf
- elif m := re.fullmatch(re_line_bitfield_, line):
- bf = get_type(m.group("name"), Bitfield)
- parse_bitspec(version, bf, m.group("member"))
-
- prev = bf
- elif m := re.fullmatch(re_line_struct, line):
- match m.group("op"):
- case "=":
- struct = Struct()
- struct.name = m.group("name")
- struct.in_versions.add(version)
- struct.members = []
- parse_members(version, env, struct, m.group("members"))
-
- env[struct.name] = struct
- prev = struct
- case "+=":
- struct = get_type(m.group("name"), Struct)
- parse_members(version, env, struct, m.group("members"))
-
- prev = struct
- elif m := re.fullmatch(re_line_msg, line):
- match m.group("op"):
- case "=":
- msg = Message()
- msg.name = m.group("name")
- msg.in_versions.add(version)
- msg.members = []
- parse_members(version, env, msg, m.group("members"))
-
- env[msg.name] = msg
- prev = msg
- case "+=":
- msg = get_type(m.group("name"), Message)
- parse_members(version, env, msg, m.group("members"))
-
- prev = msg
- elif m := re.fullmatch(re_line_cont, line):
- match prev:
- case Bitfield():
- parse_bitspec(version, prev, m.group("specs"))
- case Struct(): # and Message()
- parse_members(version, env, prev, m.group("specs"))
- case _:
- raise SyntaxError(
- "continuation line must come after a bitfield, struct, or msg line"
- )
- else:
- raise SyntaxError(f"invalid line {repr(line)}")
- if not version:
- raise SyntaxError("must have exactly 1 version line")
-
- typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]
-
- for typ in [typ for typ in typs if isinstance(typ, Struct)]:
- valid_syms = ["end", *["&" + m.name for m in typ.members]]
- for member in typ.members:
- for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, ExprSym) and tok.name not in valid_syms:
- raise ValueError(
- f"{typ.name}.{member.name}: invalid sym: {tok.name}"
- )
def tab_ljust(s: str, width: int) -> str:
    """Left-justify `s` to `width` display columns using spaces.

    The current width of `s` is measured after expanding tabs to
    8-column tab stops.  If `s` is already at least `width` columns
    wide it is returned unchanged; the tabs in `s` itself are never
    rewritten, only spaces are appended.
    """
    missing = width - len(s.expandtabs(tabsize=8))
    if missing > 0:
        return s + " " * missing
    return s
- return version, typs
def add_prefix(p: str, s: str) -> str:
    """Prepend prefix `p` to identifier `s`, keeping a leading underscore first.

    ``add_prefix("P_", "x")`` gives ``"P_x"``, while
    ``add_prefix("P_", "_x")`` gives ``"_P_x"``: a single leading
    underscore on `s` is moved in front of the prefix.
    """
    underscore = "_" if s.startswith("_") else ""
    return underscore + p + s[len(underscore) :]
-# Generate C ###################################################################
-idprefix = "lib9p_"
def c_macro(full: str) -> str:
    """Turn multi-line text into a C macro body with aligned line continuations.

    Trailing whitespace is stripped from the whole text and from each
    line.  Every line except the last is then padded (via
    `tab_ljust`) to the width of the widest non-final line and
    terminated with `` \\``; the result ends with exactly one newline.

    `full` must contain at least one newline after stripping (checked
    with an assertion).
    """
    full = full.rstrip()
    assert "\n" in full
    rows = [row.rstrip() for row in full.split("\n")]
    # The final row carries no continuation backslash, so it does not
    # take part in choosing the column the backslashes are aligned to.
    *continued, _final = rows
    column = max(len(row.expandtabs(tabsize=8)) for row in continued)
    joined = " \\\n".join(tab_ljust(row, column) for row in rows)
    # The padding that was added to the final row is removed again here.
    return joined.rstrip() + "\n"
def c_ver_enum(ver: str) -> str:
    """Return the C enumerator name for protocol version `ver`.

    The name is the upper-cased `idprefix`, then ``VER_``, then `ver`
    with every ``.`` replaced by ``_``.
    """
    return "".join([idprefix.upper(), "VER_", ver.replace(".", "_")])
def c_ver_ifdef(versions: typing.Collection[str]) -> str:
    """Return a C preprocessor expression that is true if any of `versions` is enabled.

    Each version contributes its ``CONFIG_9P_ENABLE_<ver>`` macro name
    (dots replaced by underscores); the names are sorted and joined
    with ``||``.  An empty collection yields an empty string.
    """
    macros = ["CONFIG_9P_ENABLE_" + ver.replace(".", "_") for ver in sorted(versions)]
    return " || ".join(macros)
def c_ver_cond(versions: typing.Collection[str]) -> str:
    """Return a C run-time expression testing whether `ctx` speaks one of `versions`.

    A single version yields a bare ``is_ver(ctx, <ver>)`` call (dots
    replaced by underscores).  Any other number of versions yields the
    sorted per-version calls joined with ``||`` and wrapped in
    ``( ... )``.
    """
    checks = [f"is_ver(ctx, {ver.replace('.', '_')})" for ver in sorted(versions)]
    if len(checks) == 1:
        return checks[0]
    return "( " + " || ".join(checks) + " )"
def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
    """Return the C spelling of the IDL type `typ`.

    `parent` is the struct member the type is being spelled for, if
    any.  It only matters for 1-byte primitives: when the member is
    counted (`parent.cnt` is set), the element type is spelled
    ``[[gnu::nonstring]] char`` instead of ``uint8_t``.

    Raises ValueError if `typ` is none of the known IDL type classes.
    """
    if isinstance(typ, idl.Primitive):
        if typ.value == 1 and parent and parent.cnt:  # SPECIAL (string)
            return "[[gnu::nonstring]] char"
        return f"uint{typ.value*8}_t"
    if isinstance(typ, (idl.Number, idl.Bitfield)):
        return f"{idprefix}{typ.name}_t"
    # Message is tested before Struct, matching the original case order.
    if isinstance(typ, idl.Message):
        return f"struct {idprefix}msg_{typ.name}"
    if isinstance(typ, idl.Struct):
        return f"struct {idprefix}{typ.name}"
    raise ValueError(f"not a type: {typ.__class__.__name__}")
def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str:
    """Render an IDL expression as C source text.

    Tokens are rendered in order and joined with single spaces:
    operators as themselves, literals via ``str()``, the symbols
    ``s32_max``/``s64_max`` as ``INT32_MAX``/``INT64_MAX``, and every
    other symbol as whatever `lookup_sym(name)` returns.
    """
    well_known = {"s32_max": "INT32_MAX", "s64_max": "INT64_MAX"}

    def render(tok: object) -> str:
        if isinstance(tok, idl.ExprOp):
            return tok.op
        if isinstance(tok, idl.ExprLit):
            return str(tok.val)
        if isinstance(tok, idl.ExprSym):
            if tok.name in well_known:
                return well_known[tok.name]
            return lookup_sym(tok.name)
        assert False

    return " ".join(render(tok) for tok in expr.tokens)
+
+
_ifdef_stack: list[str | None] = []
@@ -496,7 +111,7 @@ def ifdef_push(n: int, _newval: str) -> str:
# Grow the stack as needed
global _ifdef_stack
while len(_ifdef_stack) < n:
- _ifdef_stack += [None]
+ _ifdef_stack.append(None)
# Set some variables
parentval: str | None = None
@@ -531,21 +146,253 @@ def ifdef_pop(n: int) -> str:
return ret
-def gen_h(versions: set[str], typs: list[Type]) -> str:
+# topo_sorted() ################################################################
+
+
def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]:
    """Yield the types in dependency order (dependencies before dependents).

    Numbers and bitfields are added with no dependencies.  A struct
    (or message) depends on the type of each of its members, except
    members whose type is a primitive.  Entries of `typs` that are
    none of these classes are not added themselves.
    """
    sorter: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter()
    for typ in typs:
        if isinstance(typ, (idl.Number, idl.Bitfield)):
            sorter.add(typ)
        elif isinstance(typ, idl.Struct):  # and idl.Message
            member_types = (
                member.typ
                for member in typ.members
                if not isinstance(member.typ, idl.Primitive)
            )
            sorter.add(typ, *member_types)
    return sorter.static_order()
+
+
+# walk() #######################################################################
+
+
class Path:
    """A root type plus the chain of struct members leading to a nested value.

    Instances are treated as immutable by `add()` and `parent()`,
    which return new `Path` objects rather than modifying `self`.
    """

    root: idl.Type  # the type the path starts from
    elems: list[idl.StructMember]  # members traversed so far, outermost first

    def __init__(
        self, root: idl.Type, elems: list[idl.StructMember] | None = None
    ) -> None:
        self.root = root
        self.elems = [] if elems is None else elems

    def add(self, elem: idl.StructMember) -> "Path":
        """Return a new path that descends one step further, into `elem`."""
        return Path(self.root, [*self.elems, elem])

    def parent(self) -> "Path":
        """Return a new path with the last element dropped."""
        return Path(self.root, self.elems[:-1])

    def c_str(self, base: str, loopdepth: int = 0) -> str:
        """Return the path as C member-access text appended to `base`.

        Member names are joined with ``.``.  Each counted member gets a
        subscript using a single-letter loop variable: ``[i]`` for the
        first one when `loopdepth` is 0, then ``[j]``, and so on.
        """
        pieces: list[str] = []
        for elem in self.elems:
            piece = elem.name
            if elem.cnt:
                piece += f"[{chr(ord('i')+loopdepth)}]"
                loopdepth += 1
            pieces.append(piece)
        return base + ".".join(pieces)

    def __str__(self) -> str:
        return self.c_str(self.root.name + "->")
+
+
class WalkCmd(enum.Enum):
    """Instruction a walk handler returns to control the traversal."""

    # Descend into the members of the current node (if it is a struct).
    KEEP_GOING = 1
    # Do not descend into the current node, but continue the walk elsewhere.
    DONT_RECURSE = 2
    # Stop the whole walk.
    ABORT = 3
+
+
+type WalkHandler = typing.Callable[
+ [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
+]
+
+
def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
    """Visit the node at `path`, then (if asked to) each of its struct members.

    `handle` is called first.  For struct nodes its command decides
    what happens next: KEEP_GOING recurses into every member until one
    of them reports ABORT; DONT_RECURSE skips the members and is
    reported to the caller as KEEP_GOING; ABORT is passed through.
    For non-struct nodes the command is returned unchanged.  The
    handler's exit callback, if any, runs last.
    """
    typ = path.root if not path.elems else path.elems[-1].typ

    cmd, on_exit = handle(path)

    if isinstance(typ, idl.Struct):
        if cmd == WalkCmd.KEEP_GOING:
            # any() stops at the first member that aborts, like a `break`.
            if any(
                _walk(path.add(member), handle) == WalkCmd.ABORT
                for member in typ.members
            ):
                cmd = WalkCmd.ABORT
        elif cmd == WalkCmd.DONT_RECURSE:
            # The request only concerned this node; the caller carries on.
            cmd = WalkCmd.KEEP_GOING
        elif cmd != WalkCmd.ABORT:
            assert False, f"invalid cmd: {cmd}"

    if on_exit:
        on_exit()
    return cmd
+
+
def walk(typ: idl.Type, handle: WalkHandler) -> None:
    """Walk `typ` and, recursively, its struct members, calling `handle` on each.

    The traversal starts from an empty `Path` rooted at `typ`; the
    final command produced by the walk is discarded.
    """
    start = Path(typ)
    _walk(start, handle)
+
+
+# get_buffer_size() ############################################################
+
+
class BufferSize:
    """Size figures accumulated by `get_buffer_size()`; all start at zero/empty."""

    min_size: int  # really just here to sanity-check against typ.min_size(version)
    exp_size: int  # "expected" or max-reasonable size
    max_size: int  # really just here to sanity-check against typ.max_size(version)
    max_copy: int
    max_copy_extra: str  # text appended after max_copy when it is printed
    max_iov: int
    max_iov_extra: str  # text appended after max_iov when it is printed
    _starts_with_copy: bool
    _ends_with_copy: bool

    def __init__(self) -> None:
        self.min_size = self.exp_size = self.max_size = 0
        self.max_copy = self.max_iov = 0
        self.max_copy_extra = self.max_iov_extra = ""
        self._starts_with_copy = self._ends_with_copy = False
+
+
# Compute the BufferSize figures (min/expected/max byte sizes, copied
# bytes, and iovec count) for `typ` as it exists in protocol `version`.
# `typ` must be a primitive or a type that is present in `version`.
# NOTE(review): the indentation of this block was lost in this view, so
# the comments below deliberately avoid statements that depend on how
# the branches inside handle() nest -- confirm against the real file.
+def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
+ assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)
+
+ ret = BufferSize()
+
# Anything that is not a struct has one fixed size: every size field is
# its static_size, and it is counted as a single copied iovec.
+ if not isinstance(typ, idl.Struct):
+ assert typ.static_size
+ ret.min_size = typ.static_size
+ ret.exp_size = typ.static_size
+ ret.max_size = typ.static_size
+ ret.max_copy = typ.static_size
+ ret.max_iov = 1
+ ret._starts_with_copy = True
+ ret._ends_with_copy = True
+ return ret
+
# Walk handler: accumulates into `ret` for each visited member.  The
# root node has no path elements and contributes nothing by itself.
+ def handle(path: Path) -> tuple[WalkCmd, None]:
+ nonlocal ret
+ if path.elems:
+ child = path.elems[-1]
# Members that do not exist in this protocol version are skipped.
+ if version not in child.in_versions:
+ return WalkCmd.DONT_RECURSE, None
+ if child.cnt:
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret.max_iov += 1
+ # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
+ ret.exp_size += 27 if child.name == "utf8" else 8192
+ ret.max_size += child.max_cnt
+ ret._ends_with_copy = False
+ return WalkCmd.DONT_RECURSE, None
# Counted members of any other element type are sized recursively.
+ sub = get_buffer_size(child.typ, version)
+ ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM
+ ret.max_size += sub.max_size * child.max_cnt
+ if child.name == "wname" and path.root.name in (
+ "Tsread",
+ "Tswrite",
+ ): # SPECIAL (9P2000.e)
+ assert ret._ends_with_copy
+ assert sub._starts_with_copy
+ assert not sub._ends_with_copy
# For these two messages the per-element cost is recorded as text that
# multiplies by the CONFIG_9P_MAX_9P2000_e_WELEM macro, rather than
# being added into the integer totals.
+ ret.max_copy_extra = (
+ f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
+ )
+ ret.max_iov_extra = (
+ f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
+ )
+ ret.max_iov -= 1
+ else:
+ ret.max_copy += sub.max_copy * child.max_cnt
+ if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy
+ ret.max_iov += 1
+ else: # contains zero-copy segments
+ ret.max_iov += sub.max_iov * child.max_cnt
+ if ret._ends_with_copy and sub._starts_with_copy:
+ # we can merge this one
+ ret.max_iov -= 1
+ if (
+ sub._ends_with_copy
+ and sub._starts_with_copy
+ and sub.max_iov > 1
+ ):
+ # we can merge these
+ ret.max_iov -= child.max_cnt - 1
+ ret._ends_with_copy = sub._ends_with_copy
+ return WalkCmd.DONT_RECURSE, None
# Fixed-size, non-struct members add their static_size to every byte
# total; struct members are left for the walk to descend into.
+ elif not isinstance(child.typ, idl.Struct):
+ assert child.typ.static_size
+ if not ret._ends_with_copy:
+ if ret.max_size == 0:
+ ret._starts_with_copy = True
+ ret.max_iov += 1
+ ret._ends_with_copy = True
+ ret.min_size += child.typ.static_size
+ ret.exp_size += child.typ.static_size
+ ret.max_size += child.typ.static_size
+ ret.max_copy += child.typ.static_size
+ return WalkCmd.KEEP_GOING, None
+
+ walk(typ, handle)
# Cross-check the accumulated totals against the sizes that the type
# itself reports for this version.
+ assert ret.min_size == typ.min_size(version)
+ assert ret.max_size == typ.max_size(version)
+ return ret
+
+
+# Generate .h ##################################################################
+
+
+def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
#ifndef _LIB9P_9P_H_
- #error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
+\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
#endif
#include <stdint.h> /* for uint{{n}}_t types */
+
+#include <libhw/generic/net.h> /* for struct iovec */
"""
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+ id2typ[msg.msgid] = msg
+
ret += f"""
-/* versions *******************************************************************/
+/* config *********************************************************************/
+
+#include "config.h"
+"""
+ for ver in sorted(versions):
+ ret += "\n"
+ ret += f"#ifndef {c_ver_ifdef({ver})}\n"
+ ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
+ if ver == "9P2000.e": # SPECIAL (9P2000.e)
+ ret += "#else\n"
+ ret += f"\t#if {c_ver_ifdef({ver})}\n"
+ ret += "\t\t#ifndef(CONFIG_9P_MAX_9P2000_e_WELEM)\n"
+ ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
+ ret += "\t\t#endif\n"
+ ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
+ ret += "\t#endif\n"
+ ret += "#endif\n"
+
+ ret += f"""
+/* enum version ***************************************************************/
enum {idprefix}version {{
"""
@@ -559,150 +406,294 @@ enum {idprefix}version {{
ret += ifdef_pop(0)
ret += f"\t{c_ver_enum('NUM')},\n"
ret += "};\n"
- ret += "\n"
- ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
ret += """
-/* non-message types **********************************************************/
+/* enum msg_type **************************************************************/
+
"""
- for typ in [typ for typ in typs if not isinstance(typ, Message)]:
+ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
+ namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message))
+ for n in range(0x100):
+ if n not in id2typ:
+ continue
+ msg = id2typ[n]
+ ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
+ ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
+ ret += ifdef_pop(0)
+ ret += "};\n"
+
+ ret += """
+/* payload types **************************************************************/
+"""
+
+ def per_version_comment(
+ typ: idl.Type, fn: typing.Callable[[idl.Type, str], str]
+ ) -> str:
+ lines: dict[str, str] = {}
+ for version in sorted(typ.in_versions):
+ lines[version] = fn(typ, version)
+ if len(set(lines.values())) == 1:
+ for _, line in lines.items():
+ return f"/* {line} */\n"
+ assert False
+ else:
+ ret = ""
+ v_width = max(len(c_ver_enum(v)) for v in typ.in_versions)
+ for version, line in lines.items():
+ ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n"
+ return ret
+
+ for typ in topo_sorted(typs):
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+
+ def sum_size(typ: idl.Type, version: str) -> str:
+ sz = get_buffer_size(typ, version)
+ assert (
+ sz.min_size <= sz.exp_size
+ and sz.exp_size <= sz.max_size
+ and sz.max_size < u64max
+ )
+ ret = ""
+ if sz.min_size == sz.max_size:
+ ret += f"size = {sz.min_size:,}"
+ else:
+ ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}"
+ if sz.max_size > u32max:
+ ret += " (warning: >UINT32_MAX)"
+ ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}"
+ return ret
+
+ ret += per_version_comment(typ, sum_size)
+
match typ:
- case Number():
+ case idl.Number():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
- case Bitfield():
+ prefix = f"{idprefix.upper()}{typ.name.upper()}_"
+ namewidth = max(len(name) for name in typ.vals)
+ for name, val in typ.vals.items():
+ ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
+ case idl.Bitfield():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
names = [
- *reversed(
- [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))]
- ),
- "",
- *[k for k in typ.names if k not in typ.bits],
+ typ.bits[n] or f" {n}" for n in reversed(range(0, len(typ.bits)))
]
- namewidth = max(len(name) for name in names)
+ if aliases := [k for k in typ.names if k not in typ.bits]:
+ names.append("")
+ names.extend(aliases)
+ prefix = f"{idprefix.upper()}{typ.name.upper()}_"
+ namewidth = max(len(add_prefix(prefix, name)) for name in names)
ret += "\n"
for name in names:
if name == "":
ret += "\n"
- elif name.startswith(" "):
- ret += ifdef_push(2, c_ver_ifdef(typ.in_versions))
- sp = " " * (
- len("# define ")
- + len(idprefix)
- + len(typ.name)
- + 1
- + namewidth
- + 2
- - len("/* unused")
- )
- ret += f"/* unused{sp}(({c_typename(typ)})(1<<{name[1:]})) */\n"
+ continue
+
+ if name.startswith(" "):
+ vers = typ.in_versions
+ c_name = ""
+ c_val = f"1<<{name[1:]}"
else:
- ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions))
- if name.startswith("_"):
- c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}"
- else:
- c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}"
- sp1 = " " if _ifdef_stack[-1] else ""
- sp2 = " " if _ifdef_stack[-1] else " "
- sp3 = " " * (2 + namewidth - len(name))
- ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n"
+ vers = typ.names[name].in_versions
+ c_name = add_prefix(prefix, name)
+ c_val = f"{typ.names[name].val}"
+
+ ret += ifdef_push(2, c_ver_ifdef(vers))
+
+ # It is important all of the `beg` strings have
+ # the same length.
+ end = ""
+ if name.startswith(" "):
+ beg = "/* unused"
+ end = " */"
+ elif _ifdef_stack[-1]:
+ beg = "# define"
+ else:
+ beg = "#define "
+
+ ret += f"{beg} {c_name.ljust(namewidth)} (({c_typename(typ)})({c_val})){end}\n"
ret += ifdef_pop(1)
- case Struct():
- typewidth = max(len(c_typename(m.typ)) for m in typ.members)
+ case idl.Struct(): # and idl.Message():
+ ret += c_typename(typ) + " {"
+ if not typ.members:
+ ret += "};\n"
+ continue
+ ret += "\n"
+
+ typewidth = max(len(c_typename(m.typ, m)) for m in typ.members)
- ret += c_typename(typ) + " {\n"
for member in typ.members:
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- c_type = c_typename(member.typ)
- if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL
- c_type = "char"
- ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
ret += ifdef_pop(1)
ret += "};\n"
ret += ifdef_pop(0)
ret += """
-/* messages *******************************************************************/
-
+/* containers *****************************************************************/
"""
- ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
- namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message))
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
- ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
- ret += ifdef_pop(0)
- ret += "};\n"
+ ret += "\n"
+ ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n"
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
- ret += "\n"
- ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += c_typename(msg) + " {"
- if not msg.members:
- ret += "};\n"
+ tmsg_max_iov: dict[str, int] = {}
+ tmsg_max_copy: dict[str, int] = {}
+ rmsg_max_iov: dict[str, int] = {}
+ rmsg_max_copy: dict[str, int] = {}
+ for typ in typs:
+ if not isinstance(typ, idl.Message):
+ continue
+ if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e)
continue
+ max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
+ max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
+ for version in typ.in_versions:
+ if version not in max_iov:
+ max_iov[version] = 0
+ max_copy[version] = 0
+ sz = get_buffer_size(typ, version)
+ if sz.max_iov > max_iov[version]:
+ max_iov[version] = sz.max_iov
+ if sz.max_copy > max_copy[version]:
+ max_copy[version] = sz.max_copy
+
+ for name, table in [
+ ("tmsg_max_iov", tmsg_max_iov),
+ ("tmsg_max_copy", tmsg_max_copy),
+ ("rmsg_max_iov", rmsg_max_iov),
+ ("rmsg_max_copy", rmsg_max_copy),
+ ]:
+ inv: dict[int, set[str]] = {}
+ for version, maxval in table.items():
+ if maxval not in inv:
+ inv[maxval] = set()
+ inv[maxval].add(version)
+
ret += "\n"
+ directive = "if"
+ seen_e = False # SPECIAL (9P2000.e)
+ for maxval in sorted(inv, reverse=True):
+ ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n"
+ indent = 1
+ if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
+ typ = next(typ for typ in typs if typ.name == "Tswrite")
+ sz = get_buffer_size(typ, "9P2000.e")
+ match name:
+ case "tmsg_max_iov":
+ maxexpr = f"{sz.max_iov}{sz.max_iov_extra}"
+ case "tmsg_max_copy":
+ maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
+ case _:
+ assert False
+ ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n"
+ ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n"
+ ret += f"\t#else\n"
+ indent += 1
+ ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n"
+ if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e)
+ ret += "\t#endif\n"
+ if "9P2000.e" in inv[maxval]:
+ seen_e = True
+ directive = "elif"
+ ret += "#endif\n"
- typewidth = max(len(c_typename(m.typ)) for m in msg.members)
+ ret += "\n"
+ ret += f"struct {idprefix}Tmsg_send_buf {{\n"
+ ret += f"\tsize_t iov_cnt;\n"
+ ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n"
+ ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n"
+ ret += "};\n"
- for member in msg.members:
- if member.val:
- continue
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
- ret += ifdef_pop(1)
- ret += "};\n"
- ret += ifdef_pop(0)
+ ret += "\n"
+ ret += f"struct {idprefix}Rmsg_send_buf {{\n"
+ ret += f"\tsize_t iov_cnt;\n"
+ ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n"
+ ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n"
+ ret += "};\n"
return ret
-def c_expr(expr: Expr) -> str:
- ret: list[str] = []
- for tok in expr.tokens:
- match tok:
- case ExprOp():
- ret += [tok.op]
- case ExprLit():
- ret += [str(tok.val)]
- case ExprSym(name="end"):
- ret += ["ctx->net_offset"]
- case ExprSym():
- ret += [f"_{tok.name[1:]}_offset"]
- return " ".join(ret)
+# Generate .c ##################################################################
-def gen_c(versions: set[str], typs: list[Type]) -> str:
+def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
-#include <assert.h>
#include <stdbool.h>
#include <stddef.h> /* for size_t */
#include <inttypes.h> /* for PRI* macros */
#include <string.h> /* for memset() */
+#include <libmisc/assert.h>
+
#include <lib9p/9p.h>
#include "internal.h"
"""
+ # utilities ################################################################
+ ret += f"""
+/* utilities ******************************************************************/
+"""
+
def used(arg: str) -> str:
return arg
def unused(arg: str) -> str:
- return f"UNUSED({arg})"
+ return f"LM_UNUSED({arg})"
+
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+ id2typ[msg.msgid] = msg
+
+ def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
+ ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
+ for ver in ["unknown", *sorted(versions)]:
+ if ver != "unknown":
+ ret += ifdef_push(1, c_ver_ifdef({ver}))
+ ret += f"\t[{c_ver_enum(ver)}] = {{\n"
+ for n in range(*rng):
+ xmsg: idl.Message | None = id2typ.get(n, None)
+ if xmsg:
+ if ver == "unknown": # SPECIAL (initialization)
+ if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
+ xmsg = None
+ else:
+ if ver not in xmsg.in_versions:
+ xmsg = None
+ if xmsg:
+ ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
+ ret += "\t},\n"
+ ret += ifdef_pop(0)
+ ret += "};\n"
+ return ret
+
+ for v in sorted(versions):
+ ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
+ ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
+ ret += "#else\n"
+ ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
+ ret += "#endif\n"
+ ret += "\n"
+ ret += "/**\n"
+ ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n"
+ ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n"
+ ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n"
+ ret += " * several version checks together.\n"
+ ret += " */\n"
+ ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n"
# strings ##################################################################
ret += f"""
/* strings ********************************************************************/
-static const char *version_strs[{c_ver_enum('NUM')}] = {{
+const char *_{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{
"""
for ver in ["unknown", *sorted(versions)]:
if ver in versions:
@@ -710,122 +701,115 @@ static const char *version_strs[{c_ver_enum('NUM')}] = {{
ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
ret += ifdef_pop(0)
ret += "};\n"
+
+ ret += "\n"
+ ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
+ ret += msg_table("msg", "name", "char *", (0, 0x100, 1))
+
+ # bitmasks #################################################################
ret += f"""
-const char *{idprefix}version_str(enum {idprefix}version ver) {{
- assert(0 <= ver && ver < {c_ver_enum('NUM')});
- return version_strs[ver];
-}}
+/* bitmasks *******************************************************************/
"""
+ for typ in typs:
+ if not isinstance(typ, idl.Bitfield):
+ continue
+ ret += "\n"
+ ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+ ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
+ verwidth = max(len(ver) for ver in versions)
+ for ver in sorted(versions):
+ ret += ifdef_push(2, c_ver_ifdef({ver}))
+ ret += (
+ f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
+ + "".join(
+ "1" if typ.bit_is_valid(bitname, ver) else "0"
+ for bitname in reversed(typ.bits)
+ )
+ + ",\n"
+ )
+ ret += ifdef_pop(1)
+ ret += "};\n"
+ ret += ifdef_pop(0)
# validate_* ###############################################################
ret += """
/* validate_* *****************************************************************/
-static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
- if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
- /* If needed-net-size overflowed uint32_t, then
- * there's no way that actual-net-size will live up to
- * that. */
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- if (ctx->net_offset > ctx->net_size)
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- return false;
+LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
+\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
+\t\t/* If needed-net-size overflowed uint32_t, then
+\t\t * there's no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\tif (ctx->net_offset > ctx->net_size)
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
}
-static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
- if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
- /* If needed-host-size overflowed size_t, then there's
- * no way that actual-net-size will live up to
- * that. */
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- return false;
+LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
+\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
+\t\t/* If needed-host-size overflowed size_t, then there's
+\t\t * no way that actual-net-size will live up to
+\t\t * that. */
+\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
+\treturn false;
}
-static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
+LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
size_t cnt,
_validate_fn_t item_fn, size_t item_host_size) {
- for (size_t i = 0; i < cnt; i++)
- if (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
- return true;
- return false;
+\tfor (size_t i = 0; i < cnt; i++)
+\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
+\t\t\treturn true;
+\treturn false;
}
-#define validate_1(ctx) _validate_size_net(ctx, 1)
-#define validate_2(ctx) _validate_size_net(ctx, 2)
-#define validate_4(ctx) _validate_size_net(ctx, 4)
-#define validate_8(ctx) _validate_size_net(ctx, 8)
+LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
+LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
+LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
+LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
- for typ in typs:
- inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ for typ in topo_sorted(typs):
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
-
- if isinstance(typ, Bitfield):
- ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
- verwidth = max(len(ver) for ver in versions)
- for ver in sorted(versions):
- ret += ifdef_push(2, c_ver_ifdef({ver}))
- ret += (
- f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
- + "".join(
- "1" if typ.bit_is_valid(bitname, ver) else "0"
- for bitname in reversed(typ.bits)
- )
- + ",\n"
- )
- ret += ifdef_pop(1)
- ret += "};\n"
-
- ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
-
- if typ.name == "d": # SPECIAL
- # Optimize... maybe the compiler could figure out to do
- # this, but let's make it obvious.
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_4(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
- ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
- ret += "}\n"
- continue
- if typ.name == "s": # SPECIAL
- # Add an extra nul-byte on the host, and validate UTF-8
- # (also, similar optimization to "d").
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_2(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
- ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
- ret += "\t\treturn true;\n"
- ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
- ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
+ ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
match typ:
- case Number():
+ case idl.Number():
ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
ret += (
f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
)
- ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ if typ.static_size == 1:
+ ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
+ else:
+ ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
ret += f"\tif (val & ~mask)\n"
ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
- case Struct(): # and Message()
+ case idl.Struct(): # and idl.Message()
if len(typ.members) == 0:
ret += "\treturn false;\n"
ret += "}\n"
continue
+ def should_save_value(member: idl.StructMember) -> bool:
+ nonlocal typ
+ assert isinstance(typ, idl.Struct)
+ return bool(
+ member.max
+ or member.val
+ or any(m.cnt == member for m in typ.members)
+ )
+
# Pass 1 - declare value variables
for member in typ.members:
- if member.max or member.val:
+ if should_save_value(member):
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
ret += f"\t{c_typename(member.typ)} {member.name};\n"
ret += ifdef_pop(1)
@@ -834,50 +818,67 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
mark_offset: set[str] = set()
for member in typ.members:
for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, ExprSym) and tok.name.startswith("&"):
+ if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
if tok.name[1:] not in mark_offset:
ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
mark_offset.add(tok.name[1:])
# Pass 3 - main pass
ret += "\treturn false\n"
- prev_size: int | None = None
for member in typ.members:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
ret += f"\t || "
if member.in_versions != typ.in_versions:
ret += "( " + c_ver_cond(member.in_versions) + " && "
if member.cnt is not None:
- assert prev_size
- ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ if member.typ.static_size == 1: # SPECIAL (zerocopy)
+ ret += f"_validate_size_net(ctx, {member.cnt.name})"
+ else:
+ ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
+ if typ.name == "s": # SPECIAL (string)
+ ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})'
else:
- if member.max or member.val:
+ if should_save_value(member):
ret += "("
if member.name in mark_offset:
ret += f"({{ _{member.name}_offset = ctx->net_offset; "
ret += f"validate_{member.typ.name}(ctx)"
if member.name in mark_offset:
ret += "; })"
- if member.max or member.val:
- bytes = member.static_size
- assert bytes
- bits = bytes * 8
- ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
+ if should_save_value(member):
+ nbytes = member.static_size
+ assert nbytes
+ if nbytes == 1:
+ ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
+ else:
+ ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
if member.in_versions != typ.in_versions:
ret += " )"
ret += "\n"
- prev_size = member.static_size
# Pass 4 - validate ,max= and ,val= constraints
for member in typ.members:
+
+ def lookup_sym(sym: str) -> str:
+ match sym:
+ case "end":
+ return "ctx->net_offset"
+ case _:
+ assert sym.startswith("&")
+ return f"_{sym[1:]}_offset"
+
if member.max:
+ assert member.static_size
+ nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.name}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n'
if member.val:
+ assert member.static_size
+ nbits = member.static_size * 8
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
- ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'
+ ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.name}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n'
ret += ifdef_pop(1)
ret += "\t ;\n"
@@ -888,38 +889,38 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += """
/* unmarshal_* ****************************************************************/
-static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
- *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 1;
+LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
+\t*out = ctx->net_bytes[ctx->net_offset];
+\tctx->net_offset += 1;
}
-static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
- *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 2;
+LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
+\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 2;
}
-static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
- *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 4;
+LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
+\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 4;
}
-static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
- *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 8;
+LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
+\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
+\tctx->net_offset += 8;
}
"""
- for typ in typs:
- inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ for typ in topo_sorted(typs):
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static {inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
+ ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
match typ:
- case Number():
+ case idl.Number():
ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case Struct():
+ case idl.Struct():
ret += "\tmemset(out, 0, sizeof(*out));\n"
for member in typ.members:
@@ -937,12 +938,15 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out)
if member.in_versions != typ.in_versions:
ret += "{\n"
ret += prefix
- ret += f"out->{member.name} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
- if typ.name in ["d", "s"]: # SPECIAL
- ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
+ if member.typ.static_size == 1: # SPECIAL (string, zerocopy)
+ ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n"
+ ret += (
+ f"{prefix}ctx->net_offset += out->{member.cnt.name};\n"
+ )
else:
+ ret += f"out->{member.name} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n"
ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
if member.in_versions != typ.in_versions:
ret += "\t}\n"
@@ -950,9 +954,6 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out)
ret += (
f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
)
- if typ.name == "s": # SPECIAL
- ret += "\tctx->extra++;\n"
- ret += "\tout->utf8[out->len] = '\\0';\n"
ret += ifdef_pop(1)
ret += "}\n"
ret += ifdef_pop(0)
@@ -961,174 +962,376 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out)
ret += """
/* marshal_* ******************************************************************/
-static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
- lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
- (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
- ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
- ctx->ctx->max_msg_size);
- return true;
-}
-
-static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
- if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
- return _marshal_too_large(ctx);
- ctx->net_bytes[ctx->net_offset] = *val;
- ctx->net_offset += 1;
- return false;
-}
+"""
+ ret += c_macro(
+ "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
+ "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
+ "\t\tctx->net_iov_cnt++;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
+ "\tctx->net_iov_cnt++;\n"
+ )
+ ret += c_macro(
+ "#define MARSHAL_BYTES(ctx, data, len)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
+ "\tctx->net_copied_size += len;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
+ )
+ ret += c_macro(
+ "#define MARSHAL_U8LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tctx->net_copied[ctx->net_copied_size] = val;\n"
+ "\tctx->net_copied_size += 1;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
+ )
+ ret += c_macro(
+ "#define MARSHAL_U16LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 2;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
+ )
+ ret += c_macro(
+ "#define MARSHAL_U32LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 4;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
+ )
+ ret += c_macro(
+ "#define MARSHAL_U64LE(ctx, val)\n"
+ "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
+ "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
+ "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
+ "\tctx->net_copied_size += 8;\n"
+ "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
+ )
-static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
- if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
- return _marshal_too_large(ctx);
- encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 2;
- return false;
-}
+    class OffsetExpr:
+        """A size/offset expression that can be rendered as C code.
+
+        The expression is the sum of three kinds of parts:
+
+        - ``static``: a constant number of bytes;
+        - ``rep``: ``(count_path, sub_expr)`` pairs, each contributing
+          ``sub_expr`` once per element counted by ``count_path``;
+        - ``cond``: sub-expressions keyed by the set of protocol
+          versions under which they contribute.
+        """
+
+        static: int
+        cond: dict[frozenset[str], "OffsetExpr"]
+        rep: list[tuple[Path, "OffsetExpr"]]
+
+        def __init__(self) -> None:
+            self.static = 0
+            self.rep = []
+            self.cond = {}
+
+        def add(self, other: "OffsetExpr") -> None:
+            """Accumulate `other` into this expression (in place)."""
+            self.static += other.static
+            self.rep += other.rep
+            for k, v in other.cond.items():
+                if k in self.cond:
+                    self.cond[k].add(v)
+                else:
+                    self.cond[k] = v
+
+        def gen_c(
+            self,
+            dsttyp: str,
+            dstvar: str,
+            root: str,
+            indent_depth: int,
+            loop_depth: int,
+        ) -> str:
+            """Return C statements that compute this expression into `dstvar`.
+
+            If `dsttyp` is non-empty, `dstvar` is declared with that
+            type and initialized; if it is empty (as in the recursive
+            calls below), the value is added to an existing `dstvar`
+            with `+=`.  `root` is the C-expression prefix handed to
+            `Path.c_str()`, `indent_depth` is the number of leading
+            tabs, and `loop_depth` selects the loop variable name
+            (`i`, `j`, ...).
+            """
+            # Terms that fit in a single `a + b + c` expression.
+            oneline: list[str] = []
+            # Loops and version-conditional blocks that need statements.
+            multiline = ""
+            if self.static:
+                oneline.append(str(self.static))
+            for cnt, sub in self.rep:
+                if not sub.cond and not sub.rep:
+                    # Fixed-size elements: just multiply by the count.
+                    if sub.static == 1:
+                        oneline.append(cnt.c_str(root))
+                    else:
+                        oneline.append(f"({cnt.c_str(root)})*{sub.static}")
+                    continue
+                loopvar = chr(ord("i") + loop_depth)
+                multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
+                multiline += sub.gen_c(
+                    "", dstvar, root, indent_depth + 1, loop_depth + 1
+                )
+                multiline += f"{'\t'*indent_depth}}}\n"
+            for vers, sub in self.cond.items():
+                multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers))
+                # Parenthesize the condition: C requires `if (...)`, and
+                # the data-writing pass of the marshal generator wraps
+                # c_ver_cond() the same way.
+                multiline += f"{'\t'*indent_depth}if ({c_ver_cond(vers)}) {{\n"
+                multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
+                multiline += f"{'\t'*indent_depth}}}\n"
+                multiline += ifdef_pop(indent_depth)
+            # Start from an empty string so that a nested expression
+            # consisting only of loops/conditionals (no declaration and
+            # no one-line terms) does not hit an unbound `ret`.
+            ret = ""
+            if dsttyp:
+                if not oneline:
+                    oneline.append("0")
+                ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
+            elif oneline:
+                ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
+            ret += multiline
+            return ret
+
+ type OffsetExprRecursion = typing.Callable[[Path], WalkCmd]
+
+ def get_offset_expr(typ: idl.Type, recurse: OffsetExprRecursion) -> OffsetExpr:
+ if not isinstance(typ, idl.Struct):
+ assert typ.static_size
+ ret = OffsetExpr()
+ ret.static = typ.static_size
+ return ret
+
+ stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]]
+
+ def pop_root() -> None:
+ assert False
+
+ def pop_cond() -> None:
+ nonlocal stack
+ key = frozenset(stack[-1][0].elems[-1].in_versions)
+ if key in stack[-2][1].cond:
+ stack[-2][1].cond[key].add(stack[-1][1])
+ else:
+ stack[-2][1].cond[key] = stack[-1][1]
+ stack = stack[:-1]
+
+ def pop_rep() -> None:
+ nonlocal stack
+ member_path = stack[-1][0]
+ member = member_path.elems[-1]
+ assert member.cnt
+ cnt_path = member_path.parent().add(member.cnt)
+ stack[-2][1].rep.append((cnt_path, stack[-1][1]))
+ stack = stack[:-1]
+
+ def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]:
+ nonlocal recurse
+
+ ret = recurse(path)
+ if ret != WalkCmd.KEEP_GOING:
+ return ret, None
+
+ nonlocal stack
+ stack_len = len(stack)
+
+ def pop() -> None:
+ nonlocal stack
+ nonlocal stack_len
+ while len(stack) > stack_len:
+ stack[-1][2]()
+
+ if path.elems:
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:
+ stack.append((path, OffsetExpr(), pop_cond))
+ if child.cnt:
+ stack.append((path, OffsetExpr(), pop_rep))
+ if not isinstance(child.typ, idl.Struct):
+ assert child.typ.static_size
+ stack[-1][1].static += child.typ.static_size
+ return ret, pop
+
+ stack = [(Path(typ), OffsetExpr(), pop_root)]
+ walk(typ, handle)
+ return stack[0][1]
+
+ def go_to_end(path: Path) -> WalkCmd:
+ return WalkCmd.KEEP_GOING
+
+ def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]:
+ def ret(path: Path) -> WalkCmd:
+ if len(path.elems) == 1 and path.elems[0].name == name:
+ return WalkCmd.ABORT
+ return WalkCmd.KEEP_GOING
-static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
- if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
- return true;
- encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 4;
- return false;
-}
+ return ret
-static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
- if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
- return true;
- encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 8;
- return false;
-}
-"""
for typ in typs:
- inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ if not (
+ isinstance(typ, idl.Message) or typ.name == "stat"
+ ): # SPECIAL (include stat)
+ continue
+ assert isinstance(typ, idl.Struct)
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- ret += f"static {inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
- match typ:
- case Number():
- ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
- case Bitfield():
- ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
- ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
- case Struct():
- if len(typ.members) == 0:
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
+ ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n"
- # Pass 1 - declare offset variables
- mark_offset = set()
- for member in typ.members:
- if member.val:
- if member.name not in mark_offset:
- ret += f"\tuint32_t _{member.name}_offset;\n"
- mark_offset.add(member.name)
- for tok in member.val.tokens:
- if isinstance(tok, ExprSym) and tok.name.startswith("&"):
- if tok.name[1:] not in mark_offset:
- ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
- mark_offset.add(tok.name[1:])
+ # Pass 1 - check size
+ max_size = max(typ.max_size(v) for v in typ.in_versions)
- # Pass 2 - main pass
- ret += "\treturn false\n"
- for member in typ.members:
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += "\t || "
- if member.in_versions != typ.in_versions:
- ret += "( " + c_ver_cond(member.in_versions) + " && "
- if member.name in mark_offset:
- ret += f"({{ _{member.name}_offset = ctx->net_offset; "
- if member.cnt:
- ret += "({ bool err = false;\n"
- ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n"
- ret += "\t \terr = "
- if typ.name in ["d", "s"]: # SPECIAL
- # Special-case is that we cast from `char` to `uint8_t`.
- ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
+        # Emit the computation of `needed_size` and the comparison
+        # against the context's max_msg_size.  The arithmetic is done in
+        # 64 bits only when the type's maximum size can exceed u32max.
+        if max_size > u32max:  # SPECIAL (9P2000.e)
+            ret += get_offset_expr(typ, go_to_end).gen_c(
+                "uint64_t", "needed_size", "val->", 1, 0
+            )
+            ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
+        else:
+            ret += get_offset_expr(typ, go_to_end).gen_c(
+                "uint32_t", "needed_size", "val->", 1, 0
+            )
+            ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
+        if isinstance(typ, idl.Message):  # SPECIAL (disable for stat)
+            # PRIu32 is a C macro, so it must sit *outside* the C string
+            # literal (`%"PRIu32"`); inside the literal it would not be
+            # expanded and the format string would be invalid.
+            ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
+            ret += f'\t\t\t"{typ.name}",\n'
+            # Even message IDs are labelled "client", odd ones "server".
+            ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
+            ret += "\t\t\tctx->ctx->max_msg_size);\n"
+        # The `return true;` and closing brace are emitted for every
+        # type, since the `if (needed_size > ...) {` above always is.
+        ret += "\t\treturn true;\n"
+        ret += "\t}\n"
+
+ # Pass 2 - write data
+ ifdef_depth = 1
+ stack: list[tuple[Path, bool]] = [(Path(typ), False)]
+
+ def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]:
+ nonlocal ret
+ nonlocal ifdef_depth
+ nonlocal stack
+ stack_len = len(stack)
+
+ def pop() -> None:
+ nonlocal ret
+ nonlocal ifdef_depth
+ nonlocal stack
+ nonlocal stack_len
+ while len(stack) > stack_len:
+ ret += f"{'\t'*(len(stack)-1)}}}\n"
+ if stack[-1][1]:
+ ifdef_depth -= 1
+ ret += ifdef_pop(ifdef_depth)
+ stack = stack[:-1]
+
+ loopdepth = sum(1 for elem in path.elems if elem.cnt)
+ struct = path.elems[-1].typ if path.elems else path.root
+ if isinstance(struct, idl.Struct):
+ offsets: list[str] = []
+ for member in struct.members:
+ if not member.val:
+ continue
+ for tok in member.val.tokens:
+ if not isinstance(tok, idl.ExprSym):
+ continue
+ if tok.name == "end" or tok.name.startswith("&"):
+ if tok.name not in offsets:
+ offsets.append(tok.name)
+ for name in offsets:
+ name_prefix = "offsetof_" + "".join(
+ m.name + "_" for m in path.elems
+ )
+ if name == "end":
+ if not path.elems:
+ nonlocal max_size
+ if max_size > u32max:
+ ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
+ else:
+ ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n"
+ continue
+ recurse: OffsetExprRecursion = go_to_end
+ else:
+ assert name.startswith("&")
+ name = name[1:]
+ recurse = go_to_tok(name)
+ expr = get_offset_expr(struct, recurse)
+ expr_prefix = path.c_str("val->", loopdepth)
+ if not expr_prefix.endswith(">"):
+ expr_prefix += "."
+ ret += expr.gen_c(
+ "uint32_t",
+ name_prefix + name,
+ expr_prefix,
+ len(stack),
+ loopdepth,
+ )
+ if path.elems:
+ child = path.elems[-1]
+ parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
+ if child.in_versions < parent.in_versions:
+ ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions))
+ ifdef_depth += 1
+ ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n"
+ stack.append((path, True))
+ if child.cnt:
+ cnt_path = path.parent().add(child.cnt)
+ if child.typ.static_size == 1: # SPECIAL (zerocopy)
+ if path.root.name == "stat": # SPECIAL (stat)
+ ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
else:
- ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n"
- ret += f"\t err; }})"
- elif member.val:
- # Just increment net_offset, don't actually marsha anything (yet).
- assert member.static_size
- ret += (
- f"({{ ctx->net_offset += {member.static_size}; false; }})"
- )
+ ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
+ return WalkCmd.KEEP_GOING, pop
+ loopvar = chr(ord("i") + loopdepth - 1)
+ ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
+ stack.append((path, False))
+ if not isinstance(child.typ, idl.Struct):
+ if child.val:
+
+ def lookup_sym(sym: str) -> str:
+ nonlocal path
+ if sym.startswith("&"):
+ sym = sym[1:]
+ return (
+ "offsetof_"
+ + "".join(m.name + "_" for m in path.elems[:-1])
+ + sym
+ )
+
+ val = c_expr(child.val, lookup_sym)
else:
- ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
- if member.name in mark_offset:
- ret += "; })"
- if member.in_versions != typ.in_versions:
- ret += " )"
- ret += "\n"
+ val = path.c_str("val->")
+ if isinstance(child.typ, idl.Bitfield):
+ val += f" & {child.typ.name}_masks[ctx->ctx->version]"
+ ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
+ return WalkCmd.KEEP_GOING, pop
- # Pass 3 - marshal ,val= members
- for member in typ.members:
- if member.val:
- assert member.static_size
- ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n"
+ walk(typ, handle)
- ret += ifdef_pop(1)
- ret += "\t ;\n"
+ ret += "\treturn false;\n"
ret += "}\n"
ret += ifdef_pop(0)
- # tables / exports #########################################################
- ret += f"""
-/* tables / exports ***********************************************************/
-
-#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
- .name = #typ, \\
- .basesize = sizeof(struct {idprefix}msg_##typ), \\
- .validate = validate_##typ, \\
- .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
- .marshal = (_marshal_fn_t)marshal_##typ, \\
- }}
-#define _NONMSG(num) [num] = {{ \\
- .name = #num, \\
- }}
-
-struct _table_version _{idprefix}versions[{c_ver_enum('NUM')}] = {{
+ # function tables ##########################################################
+ ret += """
+/* function tables ************************************************************/
"""
- id2typ: dict[int, Message] = {}
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
- id2typ[msg.msgid] = msg
- for ver in ["unknown", *sorted(versions)]:
- if ver != "unknown":
- ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n"
-
- for n in range(0, 0x100):
- xmsg: Message | None = id2typ.get(n, None)
- if xmsg:
- if ver == "unknown": # SPECIAL
- if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
- xmsg = None
- else:
- if ver not in xmsg.in_versions:
- xmsg = None
- if xmsg:
- ret += f"\t\t_MSG({xmsg.name}),\n"
- else:
- ret += "\t\t_NONMSG(0x{:02X}),\n".format(n)
- ret += "\t}},\n"
+ ret += "\n"
+ ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n"
+ rerror = next(typ for typ in typs if typ.name == "Rerror")
+ ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization)
+ for ver in sorted(versions):
+ ret += ifdef_push(1, c_ver_ifdef({ver}))
+ ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n"
ret += ifdef_pop(0)
ret += "};\n"
+ ret += "\n"
+ ret += c_macro(
+ f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
+ f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n"
+ f"\t\t.validate = validate_##typ,\n"
+ f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
+ f"\t}}\n"
+ )
+ ret += c_macro(
+ f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
+ f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n"
+ f"\t}}\n"
+ )
+ ret += "\n"
+ ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2))
+ ret += "\n"
+ ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2))
+
ret += f"""
-FLATTEN bool _{idprefix}validate_stat(struct _validate_ctx *ctx) {{
- return validate_stat(ctx);
+LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{
+\treturn validate_stat(ctx);
}}
-FLATTEN void _{idprefix}unmarshal_stat(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{
- unmarshal_stat(ctx, out);
+LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{
+\tunmarshal_stat(ctx, out);
}}
-FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{
- return marshal_stat(ctx, val);
+LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{
+\treturn marshal_stat(ctx, val);
}}
"""
@@ -1136,46 +1339,40 @@ FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat
return ret
-################################################################################
+# Main #########################################################################
-class Parser:
- cache: dict[str, tuple[str, list[Type]]] = {}
-
- def parse_file(self, filename: str) -> tuple[str, list[Type]]:
- filename = os.path.normpath(filename)
- if filename not in self.cache:
-
- def get_include(other_filename: str) -> tuple[str, list[Type]]:
- return self.parse_file(os.path.join(filename, "..", other_filename))
+if __name__ == "__main__":
+ import sys
- self.cache[filename] = parse_file(filename, get_include)
- return self.cache[filename]
+ if typing.TYPE_CHECKING:
- def all(self) -> tuple[set[str], list[Type]]:
- ret_versions: set[str] = set()
- ret_typs: dict[str, Type] = {}
- for version, typs in self.cache.values():
- if version in ret_versions:
- raise ValueError(f"duplicate protocol version {repr(version)}")
- ret_versions.add(version)
- for typ in typs:
- if typ.name in ret_typs:
- if typ != ret_typs[typ.name]:
- raise ValueError(f"duplicate type name {repr(typ.name)}")
- else:
- ret_typs[typ.name] = typ
- return ret_versions, list(ret_typs.values())
+ class ANSIColors:
+ MAGENTA = "\x1b[35m"
+ RED = "\x1b[31m"
+ RESET = "\x1b[0m"
-
-if __name__ == "__main__":
- import sys
+ else:
+ from _colorize import ANSIColors # Present in Python 3.13+
if len(sys.argv) < 2:
raise ValueError("requires at least 1 .9p filename")
- parser = Parser()
+ parser = idl.Parser()
for txtname in sys.argv[1:]:
- parser.parse_file(txtname)
+ try:
+ parser.parse_file(txtname)
+ except SyntaxError as e:
+ print(
+ f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}",
+ file=sys.stderr,
+ )
+ assert e.text
+ print(f"\t{e.text}", file=sys.stderr)
+ print(
+ f"\t{ANSIColors.RED}{'~'*len(e.text)}{ANSIColors.RESET}",
+ file=sys.stderr,
+ )
+ sys.exit(2)
versions, typs = parser.all()
outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: