diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 1853 |
1 files changed, 1025 insertions, 828 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index ec42cfd..f2b4f13 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -2,493 +2,108 @@ # lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files # defining 9P protocol variants. # -# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> # SPDX-License-Identifier: AGPL-3.0-or-later import enum +import graphlib import os.path -import re -from abc import ABC, abstractmethod -from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast +import sys +import typing + +sys.path.insert(0, os.path.normpath(os.path.join(__file__, ".."))) + +import idl # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". -# Types ######################################################################## - - -class Primitive(enum.Enum): - u8 = 1 - u16 = 2 - u32 = 4 - u64 = 8 - - @property - def in_versions(self) -> set[str]: - return set() - - @property - def name(self) -> str: - return str(self.value) - - @property - def static_size(self) -> int: - return self.value - - -class Number: - name: str - in_versions: set[str] - - prim: Primitive - - def __init__(self) -> None: - self.in_versions = set() - - @property - def static_size(self) -> int: - return self.prim.static_size - - -class BitfieldVal: - name: str - in_versions: set[str] - - val: str - - def __init__(self) -> None: - self.in_versions = set() - - -class Bitfield: - name: str - in_versions: set[str] - - prim: Primitive - - bits: list[str] # bitnames - names: dict[str, BitfieldVal] # bits *and* aliases - - def __init__(self) -> None: - self.in_versions = set() - self.names = {} - - @property - def static_size(self) -> int: - return self.prim.static_size - - def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool: - """Return whether the given bit is valid in the given protocol - version. - - """ - bitname = self.bits[bit] if isinstance(bit, int) else bit - assert bitname in self.bits - if not bitname: - return False - if bitname.startswith("_"): - return False - if ver and (ver not in self.names[bitname].in_versions): - return False - return True - - -class ExprLit: - val: int - - def __init__(self, val: int) -> None: - self.val = val - - -class ExprSym: - name: str - - def __init__(self, name: str) -> None: - self.name = name - - -class ExprOp: - op: Literal["-", "+"] - - def __init__(self, op: Literal["-", "+"]) -> None: - self.op = op - -class Expr: - tokens: list[ExprLit | ExprSym | ExprOp] +# Utilities #################################################################### - def __init__(self) -> None: - self.tokens = [] - - def __bool__(self) -> bool: - return len(self.tokens) > 0 - - -class StructMember: - # from left-to-right when parsing - cnt: str | None = None - name: str - typ: "Type" - max: Expr - val: Expr - - in_versions: set[str] - - @property - def static_size(self) -> int | None: - if self.cnt: - return None - return self.typ.static_size - - -class Struct: - name: str - in_versions: set[str] - - members: list[StructMember] - - def __init__(self) -> None: - self.in_versions = set() - - @property - def static_size(self) -> int | None: - size = 0 - for member in self.members: - msize = member.static_size - if msize is None: - return None - size += msize - return size - - -class Message(Struct): - @property - def msgid(self) -> int: - assert len(self.members) >= 3 - assert self.members[1].name == "typ" - assert self.members[1].static_size == 1 - assert self.members[1].val - assert len(self.members[1].val.tokens) == 1 - assert isinstance(self.members[1].val.tokens[0], ExprLit) - return self.members[1].val.tokens[0].val - - -Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message -# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13 -T = TypeVar("T", Number, Bitfield, Struct, Message) - -# Parse *.9p ################################################################### - -re_priname = "(?:1|2|4|8)" # primitive names -re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names -re_impname = r"(?:\*|" + re_symname + ")" # names we can import -re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be - -re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be - -re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)" - -re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})" -re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)" - -re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?" - - -def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: - spec = spec.strip() - - bit: int | None - val: BitfieldVal - if m := re.fullmatch(re_bitspec_bit, spec): - bit = int(m.group("bit")) - name = m.group("name") - - val = BitfieldVal() - val.name = name - val.val = f"1<<{bit}" - val.in_versions.add(ver) - - if bit < 0 or bit >= len(bf.bits): - raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") - if bf.bits[bit]: - raise ValueError(f"{bf.name}: bit {bit} already assigned") - bf.bits[bit] = val.name - elif m := re.fullmatch(re_bitspec_alias, spec): - name = m.group("name") - valstr = m.group("val") - - val = BitfieldVal() - val.name = name - val.val = valstr - val.in_versions.add(ver) - else: - raise SyntaxError(f"invalid bitfield spec {repr(spec)}") - - if val.name in bf.names: - raise ValueError(f"{bf.name}: name {val.name} already assigned") - bf.names[val.name] = val - - -def parse_expr(expr: str) -> Expr: - assert re.fullmatch(re_expr, expr) - ret = Expr() - for tok in re.split("([-+])", expr): - if tok == "-" or tok == "+": - # I, for the life of me, do not understand why I need this - # cast() to keep mypy happy. - ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] - elif re.fullmatch("[0-9]+", tok): - ret.tokens += [ExprLit(int(tok))] - else: - ret.tokens += [ExprSym(tok)] - return ret - - -def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: - for spec in specs.split(): - m = re.fullmatch(re_memberspec, spec) - if not m: - raise SyntaxError(f"invalid member spec {repr(spec)}") - - member = StructMember() - member.in_versions = {ver} - - member.name = m.group("name") - if any(x.name == member.name for x in struct.members): - raise ValueError(f"duplicate member name {repr(member.name)}") - - if m.group("typ") not in env: - raise NameError(f"Unknown type {repr(m.group(2))}") - member.typ = env[m.group("typ")] - - if cnt := m.group("cnt"): - if len(struct.members) == 0 or struct.members[-1].name != cnt: - raise ValueError(f"list count must be previous item: {repr(cnt)}") - if not isinstance(struct.members[-1].typ, Primitive): - raise ValueError(f"list count must be an integer type: {repr(cnt)}") - member.cnt = cnt +idprefix = "lib9p_" - if maxstr := m.group("max"): - if (not isinstance(member.typ, Primitive)) or member.cnt: - raise ValueError("',max=' may only be specified on a non-repeated atom") - member.max = parse_expr(maxstr) - else: - member.max = Expr() +u32max = (1 << 32) - 1 +u64max = (1 << 64) - 1 - if valstr := m.group("val"): - if (not isinstance(member.typ, Primitive)) or member.cnt: - raise ValueError("',val=' may only be specified on a non-repeated atom") - member.val = parse_expr(valstr) - else: - member.val = Expr() - - struct.members += [member] - - -def re_string(grpname: str) -> str: - return f'"(?P<{grpname}>[^"]*)"' - - -re_line_version = f"version\\s+{re_string('version')}" -re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)" -re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" -re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" -re_line_bitfield_ = ( - f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}" -) -re_line_struct = ( - f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" -) -re_line_msg = ( - f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" -) -re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg - - -def parse_file( - filename: str, get_include: Callable[[str], tuple[str, list[Type]]] -) -> tuple[str, list[Type]]: - version: str | None = None - env: dict[str, Type] = { - "1": Primitive.u8, - "2": Primitive.u16, - "4": Primitive.u32, - "8": Primitive.u64, - } - - def get_type(name: str, tc: type[T]) -> T: - nonlocal env - if name not in env: - raise NameError(f"Unknown type {repr(name)}") - ret = env[name] - if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): - raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") - return ret - with open(filename, "r") as fh: - prev: Type | None = None - for line in fh: - line = line.split("#", 1)[0].rstrip() - if not line: - continue - if m := re.fullmatch(re_line_version, line): - if version: - raise SyntaxError("must have exactly 1 version line") - version = m.group("version") - continue - if not version: - raise SyntaxError("must have exactly 1 version line") - - if m := re.fullmatch(re_line_import, line): - other_version, other_typs = get_include(m.group("file")) - for symname in m.group("syms").split(sep=","): - symname = symname.strip() - for typ in other_typs: - if typ.name == symname or symname == "*": - match typ: - case Primitive(): - pass - case Number(): - typ.in_versions.add(version) - case Bitfield(): - typ.in_versions.add(version) - for val in typ.names.values(): - if other_version in val.in_versions: - val.in_versions.add(version) - case Struct(): # and Message() - typ.in_versions.add(version) - for member in typ.members: - if other_version in member.in_versions: - member.in_versions.add(version) - env[typ.name] = typ - elif m := re.fullmatch(re_line_num, line): - num = Number() - num.name = m.group("name") - num.in_versions.add(version) - - prim = env[m.group("prim")] - assert isinstance(prim, Primitive) - num.prim = prim - - env[num.name] = num - prev = num - elif m := re.fullmatch(re_line_bitfield, line): - bf = Bitfield() - bf.name = m.group("name") - bf.in_versions.add(version) - - prim = env[m.group("prim")] - assert isinstance(prim, Primitive) - bf.prim = prim - - bf.bits = (prim.static_size * 8) * [""] - - env[bf.name] = bf - prev = bf - elif m := re.fullmatch(re_line_bitfield_, line): - bf = get_type(m.group("name"), Bitfield) - parse_bitspec(version, bf, m.group("member")) - - prev = bf - elif m := re.fullmatch(re_line_struct, line): - match m.group("op"): - case "=": - struct = Struct() - struct.name = m.group("name") - struct.in_versions.add(version) - struct.members = [] - parse_members(version, env, struct, m.group("members")) - - env[struct.name] = struct - prev = struct - case "+=": - struct = get_type(m.group("name"), Struct) - parse_members(version, env, struct, m.group("members")) - - prev = struct - elif m := re.fullmatch(re_line_msg, line): - match m.group("op"): - case "=": - msg = Message() - msg.name = m.group("name") - msg.in_versions.add(version) - msg.members = [] - parse_members(version, env, msg, m.group("members")) - - env[msg.name] = msg - prev = msg - case "+=": - msg = get_type(m.group("name"), Message) - parse_members(version, env, msg, m.group("members")) - - prev = msg - elif m := re.fullmatch(re_line_cont, line): - match prev: - case Bitfield(): - parse_bitspec(version, prev, m.group("specs")) - case Struct(): # and Message() - parse_members(version, env, prev, m.group("specs")) - case _: - raise SyntaxError( - "continuation line must come after a bitfield, struct, or msg line" - ) - else: - raise SyntaxError(f"invalid line {repr(line)}") - if not version: - raise SyntaxError("must have exactly 1 version line") - - typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] - - for typ in [typ for typ in typs if isinstance(typ, Struct)]: - valid_syms = ["end", *["&" + m.name for m in typ.members]] - for member in typ.members: - for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, ExprSym) and tok.name not in valid_syms: - raise ValueError( - f"{typ.name}.{member.name}: invalid sym: {tok.name}" - ) +def tab_ljust(s: str, width: int) -> str: + cur = len(s.expandtabs(tabsize=8)) + if cur >= width: + return s + return s + " " * (width - cur) - return version, typs +def add_prefix(p: str, s: str) -> str: + if s.startswith("_"): + return "_" + p + s[1:] + return p + s -# Generate C ################################################################### -idprefix = "lib9p_" +def c_macro(full: str) -> str: + full = full.rstrip() + assert "\n" in full + lines = [l.rstrip() for l in full.split("\n")] + width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1]) + lines = [tab_ljust(l, width) for l in lines] + return " \\\n".join(lines).rstrip() + "\n" def c_ver_enum(ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" -def c_ver_ifdef(versions: set[str]) -> str: +def c_ver_ifdef(versions: typing.Collection[str]) -> str: return " || ".join( - f"defined(CONFIG_9P_ENABLE_{v.replace('.', '_')})" for v in sorted(versions) + f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) ) -def c_ver_cond(versions: set[str]) -> str: +def c_ver_cond(versions: typing.Collection[str]) -> str: if len(versions) == 1: - return f"(ctx->ctx->version=={c_ver_enum(next(v for v in versions))})" + v = next(v for v in versions) + return f"is_ver(ctx, {v.replace('.', '_')})" return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" -def c_typename(typ: Type) -> str: +def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: match typ: - case Primitive(): + case idl.Primitive(): + if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) + return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" - case Number(): + case idl.Number(): return f"{idprefix}{typ.name}_t" - case Bitfield(): + case idl.Bitfield(): return f"{idprefix}{typ.name}_t" - case Message(): + case idl.Message(): return f"struct {idprefix}msg_{typ.name}" - case Struct(): + case idl.Struct(): return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") +def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: + ret: list[str] = [] + for tok in expr.tokens: + match tok: + case idl.ExprOp(): + ret.append(tok.op) + case idl.ExprLit(): + ret.append(str(tok.val)) + case idl.ExprSym(name="s32_max"): + ret.append("INT32_MAX") + case idl.ExprSym(name="s64_max"): + ret.append("INT64_MAX") + case idl.ExprSym(): + ret.append(lookup_sym(tok.name)) + case _: + assert False + return " ".join(ret) + + _ifdef_stack: list[str | None] = [] @@ -496,7 +111,7 @@ def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed global _ifdef_stack while len(_ifdef_stack) < n: - _ifdef_stack += [None] + _ifdef_stack.append(None) # Set some variables parentval: str | None = None @@ -531,21 +146,253 @@ def ifdef_pop(n: int) -> str: return ret -def gen_h(versions: set[str], typs: list[Type]) -> str: +# topo_sorted() ################################################################ + + +def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]: + ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter() + for typ in typs: + match typ: + case idl.Number(): + ts.add(typ) + case idl.Bitfield(): + ts.add(typ) + case idl.Struct(): # and idl.Message(): + deps = [ + member.typ + for member in typ.members + if not isinstance(member.typ, idl.Primitive) + ] + ts.add(typ, *deps) + return ts.static_order() + + +# walk() ####################################################################### + + +class Path: + root: idl.Type + elems: list[idl.StructMember] + + def __init__( + self, root: idl.Type, elems: list[idl.StructMember] | None = None + ) -> None: + self.root = root + self.elems = elems if elems is not None else [] + + def add(self, elem: idl.StructMember) -> "Path": + return Path(self.root, self.elems + [elem]) + + def parent(self) -> "Path": + return Path(self.root, self.elems[:-1]) + + def c_str(self, base: str, loopdepth: int = 0) -> str: + ret = base + for i, elem in enumerate(self.elems): + if i > 0: + ret += "." + ret += elem.name + if elem.cnt: + ret += f"[{chr(ord('i')+loopdepth)}]" + loopdepth += 1 + return ret + + def __str__(self) -> str: + return self.c_str(self.root.name + "->") + + +class WalkCmd(enum.Enum): + KEEP_GOING = 1 + DONT_RECURSE = 2 + ABORT = 3 + + +type WalkHandler = typing.Callable[ + [Path], tuple[WalkCmd, typing.Callable[[], None] | None] +] + + +def _walk(path: Path, handle: WalkHandler) -> WalkCmd: + typ = path.elems[-1].typ if path.elems else path.root + + ret, atexit = handle(path) + + if isinstance(typ, idl.Struct): + match ret: + case WalkCmd.KEEP_GOING: + for member in typ.members: + if _walk(path.add(member), handle) == WalkCmd.ABORT: + ret = WalkCmd.ABORT + break + case WalkCmd.DONT_RECURSE: + ret = WalkCmd.KEEP_GOING + case WalkCmd.ABORT: + ret = WalkCmd.ABORT + case _: + assert False, f"invalid cmd: {ret}" + + if atexit: + atexit() + return ret + + +def walk(typ: idl.Type, handle: WalkHandler) -> None: + _walk(Path(typ), handle) + + +# get_buffer_size() ############################################################ + + +class BufferSize: + min_size: int # really just here to sanity-check against typ.min_size(version) + exp_size: int # "expected" or max-reasonable size + max_size: int # really just here to sanity-check against typ.max_size(version) + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + _starts_with_copy: bool + _ends_with_copy: bool + + def __init__(self) -> None: + self.min_size = 0 + self.exp_size = 0 + self.max_size = 0 + self.max_copy = 0 + self.max_copy_extra = "" + self.max_iov = 0 + self.max_iov_extra = "" + self._starts_with_copy = False + self._ends_with_copy = False + + +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) + + ret = BufferSize() + + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret.min_size = typ.static_size + ret.exp_size = typ.static_size + ret.max_size = typ.static_size + ret.max_copy = typ.static_size + ret.max_iov = 1 + ret._starts_with_copy = True + ret._ends_with_copy = True + return ret + + def handle(path: Path) -> tuple[WalkCmd, None]: + nonlocal ret + if path.elems: + child = path.elems[-1] + if version not in child.in_versions: + return WalkCmd.DONT_RECURSE, None + if child.cnt: + if child.typ.static_size == 1: # SPECIAL (zerocopy) + ret.max_iov += 1 + # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data + ret.exp_size += 27 if child.name == "utf8" else 8192 + ret.max_size += child.max_cnt + ret._ends_with_copy = False + return WalkCmd.DONT_RECURSE, None + sub = get_buffer_size(child.typ, version) + ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM + ret.max_size += sub.max_size * child.max_cnt + if child.name == "wname" and path.root.name in ( + "Tsread", + "Tswrite", + ): # SPECIAL (9P2000.e) + assert ret._ends_with_copy + assert sub._starts_with_copy + assert not sub._ends_with_copy + ret.max_copy_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" + ) + ret.max_iov_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" + ) + ret.max_iov -= 1 + else: + ret.max_copy += sub.max_copy * child.max_cnt + if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy + ret.max_iov += 1 + else: # contains zero-copy segments + ret.max_iov += sub.max_iov * child.max_cnt + if ret._ends_with_copy and sub._starts_with_copy: + # we can merge this one + ret.max_iov -= 1 + if ( + sub._ends_with_copy + and sub._starts_with_copy + and sub.max_iov > 1 + ): + # we can merge these + ret.max_iov -= child.max_cnt - 1 + ret._ends_with_copy = sub._ends_with_copy + return WalkCmd.DONT_RECURSE, None + elif not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + if not ret._ends_with_copy: + if ret.max_size == 0: + ret._starts_with_copy = True + ret.max_iov += 1 + ret._ends_with_copy = True + ret.min_size += child.typ.static_size + ret.exp_size += child.typ.static_size + ret.max_size += child.typ.static_size + ret.max_copy += child.typ.static_size + return WalkCmd.KEEP_GOING, None + + walk(typ, handle) + assert ret.min_size == typ.min_size(version) + assert ret.max_size == typ.max_size(version) + return ret + + +# Generate .h ################################################################## + + +def gen_h(versions: set[str], typs: list[idl.Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ - #error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead +\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead #endif #include <stdint.h> /* for uint{{n}}_t types */ + +#include <libhw/generic/net.h> /* for struct iovec */ """ + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + ret += f""" -/* versions *******************************************************************/ +/* config *********************************************************************/ + +#include "config.h" +""" + for ver in sorted(versions): + ret += "\n" + ret += f"#ifndef {c_ver_ifdef({ver})}\n" + ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" + if ver == "9P2000.e": # SPECIAL (9P2000.e) + ret += "#else\n" + ret += f"\t#if {c_ver_ifdef({ver})}\n" + ret += "\t\t#ifndef(CONFIG_9P_MAX_9P2000_e_WELEM)\n" + ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += "\t\t#endif\n" + ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" + ret += "\t#endif\n" + ret += "#endif\n" + + ret += f""" +/* enum version ***************************************************************/ enum {idprefix}version {{ """ @@ -559,150 +406,294 @@ enum {idprefix}version {{ ret += ifdef_pop(0) ret += f"\t{c_ver_enum('NUM')},\n" ret += "};\n" - ret += "\n" - ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" ret += """ -/* non-message types **********************************************************/ +/* enum msg_type **************************************************************/ + """ - for typ in [typ for typ in typs if not isinstance(typ, Message)]: + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message)) + for n in range(0x100): + if n not in id2typ: + continue + msg = id2typ[n] + ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) + ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" + ret += ifdef_pop(0) + ret += "};\n" + + ret += """ +/* payload types **************************************************************/ +""" + + def per_version_comment( + typ: idl.Type, fn: typing.Callable[[idl.Type, str], str] + ) -> str: + lines: dict[str, str] = {} + for version in sorted(typ.in_versions): + lines[version] = fn(typ, version) + if len(set(lines.values())) == 1: + for _, line in lines.items(): + return f"/* {line} */\n" + assert False + else: + ret = "" + v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) + for version, line in lines.items(): + ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n" + return ret + + for typ in topo_sorted(typs): ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + + def sum_size(typ: idl.Type, version: str) -> str: + sz = get_buffer_size(typ, version) + assert ( + sz.min_size <= sz.exp_size + and sz.exp_size <= sz.max_size + and sz.max_size < u64max + ) + ret = "" + if sz.min_size == sz.max_size: + ret += f"size = {sz.min_size:,}" + else: + ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" + if sz.max_size > u32max: + ret += " (warning: >UINT32_MAX)" + ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" + return ret + + ret += per_version_comment(typ, sum_size) + match typ: - case Number(): + case idl.Number(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" - case Bitfield(): + prefix = f"{idprefix.upper()}{typ.name.upper()}_" + namewidth = max(len(name) for name in typ.vals) + for name, val in typ.vals.items(): + ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" + case idl.Bitfield(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" names = [ - *reversed( - [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))] - ), - "", - *[k for k in typ.names if k not in typ.bits], + typ.bits[n] or f" {n}" for n in reversed(range(0, len(typ.bits))) ] - namewidth = max(len(name) for name in names) + if aliases := [k for k in typ.names if k not in typ.bits]: + names.append("") + names.extend(aliases) + prefix = f"{idprefix.upper()}{typ.name.upper()}_" + namewidth = max(len(add_prefix(prefix, name)) for name in names) ret += "\n" for name in names: if name == "": ret += "\n" - elif name.startswith(" "): - ret += ifdef_push(2, c_ver_ifdef(typ.in_versions)) - sp = " " * ( - len("# define ") - + len(idprefix) - + len(typ.name) - + 1 - + namewidth - + 2 - - len("/* unused") - ) - ret += f"/* unused{sp}(({c_typename(typ)})(1<<{name[1:]})) */\n" + continue + + if name.startswith(" "): + vers = typ.in_versions + c_name = "" + c_val = f"1<<{name[1:]}" else: - ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions)) - if name.startswith("_"): - c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}" - else: - c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}" - sp1 = " " if _ifdef_stack[-1] else "" - sp2 = " " if _ifdef_stack[-1] else " " - sp3 = " " * (2 + namewidth - len(name)) - ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n" + vers = typ.names[name].in_versions + c_name = add_prefix(prefix, name) + c_val = f"{typ.names[name].val}" + + ret += ifdef_push(2, c_ver_ifdef(vers)) + + # It is important all of the `beg` strings have + # the same length. + end = "" + if name.startswith(" "): + beg = "/* unused" + end = " */" + elif _ifdef_stack[-1]: + beg = "# define" + else: + beg = "#define " + + ret += f"{beg} {c_name.ljust(namewidth)} (({c_typename(typ)})({c_val})){end}\n" ret += ifdef_pop(1) - case Struct(): - typewidth = max(len(c_typename(m.typ)) for m in typ.members) + case idl.Struct(): # and idl.Message(): + ret += c_typename(typ) + " {" + if not typ.members: + ret += "};\n" + continue + ret += "\n" + + typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) - ret += c_typename(typ) + " {\n" for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - c_type = c_typename(member.typ) - if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL - c_type = "char" - ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) ret += """ -/* messages *******************************************************************/ - +/* containers *****************************************************************/ """ - ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" - namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message)) - for msg in [msg for msg in typs if isinstance(msg, Message)]: - ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" - ret += ifdef_pop(0) - ret += "};\n" + ret += "\n" + ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n" - for msg in [msg for msg in typs if isinstance(msg, Message)]: - ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += c_typename(msg) + " {" - if not msg.members: - ret += "};\n" + tmsg_max_iov: dict[str, int] = {} + tmsg_max_copy: dict[str, int] = {} + rmsg_max_iov: dict[str, int] = {} + rmsg_max_copy: dict[str, int] = {} + for typ in typs: + if not isinstance(typ, idl.Message): + continue + if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) continue + max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov + max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy + for version in typ.in_versions: + if version not in max_iov: + max_iov[version] = 0 + max_copy[version] = 0 + sz = get_buffer_size(typ, version) + if sz.max_iov > max_iov[version]: + max_iov[version] = sz.max_iov + if sz.max_copy > max_copy[version]: + max_copy[version] = sz.max_copy + + for name, table in [ + ("tmsg_max_iov", tmsg_max_iov), + ("tmsg_max_copy", tmsg_max_copy), + ("rmsg_max_iov", rmsg_max_iov), + ("rmsg_max_copy", rmsg_max_copy), + ]: + inv: dict[int, set[str]] = {} + for version, maxval in table.items(): + if maxval not in inv: + inv[maxval] = set() + inv[maxval].add(version) + ret += "\n" + directive = "if" + seen_e = False # SPECIAL (9P2000.e) + for maxval in sorted(inv, reverse=True): + ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" + indent = 1 + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + typ = next(typ for typ in typs if typ.name == "Tswrite") + sz = get_buffer_size(typ, "9P2000.e") + match name: + case "tmsg_max_iov": + maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" + case "tmsg_max_copy": + maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" + case _: + assert False + ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" + ret += f"\t#else\n" + indent += 1 + ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + ret += "\t#endif\n" + if "9P2000.e" in inv[maxval]: + seen_e = True + directive = "elif" + ret += "#endif\n" - typewidth = max(len(c_typename(m.typ)) for m in msg.members) + ret += "\n" + ret += f"struct {idprefix}Tmsg_send_buf {{\n" + ret += f"\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" + ret += "};\n" - for member in msg.members: - if member.val: - continue - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" - ret += ifdef_pop(1) - ret += "};\n" - ret += ifdef_pop(0) + ret += "\n" + ret += f"struct {idprefix}Rmsg_send_buf {{\n" + ret += f"\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" + ret += "};\n" return ret -def c_expr(expr: Expr) -> str: - ret: list[str] = [] - for tok in expr.tokens: - match tok: - case ExprOp(): - ret += [tok.op] - case ExprLit(): - ret += [str(tok.val)] - case ExprSym(name="end"): - ret += ["ctx->net_offset"] - case ExprSym(): - ret += [f"_{tok.name[1:]}_offset"] - return " ".join(ret) +# Generate .c ################################################################## -def gen_c(versions: set[str], typs: list[Type]) -> str: +def gen_c(versions: set[str], typs: list[idl.Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ -#include <assert.h> #include <stdbool.h> #include <stddef.h> /* for size_t */ #include <inttypes.h> /* for PRI* macros */ #include <string.h> /* for memset() */ +#include <libmisc/assert.h> + #include <lib9p/9p.h> #include "internal.h" """ + # utilities ################################################################ + ret += f""" +/* utilities ******************************************************************/ +""" + def used(arg: str) -> str: return arg def unused(arg: str) -> str: - return f"UNUSED({arg})" + return f"LM_UNUSED({arg})" + + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: + ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + for ver in ["unknown", *sorted(versions)]: + if ver != "unknown": + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {{\n" + for n in range(*rng): + xmsg: idl.Message | None = id2typ.get(n, None) + if xmsg: + if ver == "unknown": # SPECIAL (initialization) + if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: + xmsg = None + else: + if ver not in xmsg.in_versions: + xmsg = None + if xmsg: + ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" + ret += "\t},\n" + ret += ifdef_pop(0) + ret += "};\n" + return ret + + for v in sorted(versions): + ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" + ret += "#else\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" + ret += "#endif\n" + ret += "\n" + ret += "/**\n" + ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" + ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n" + ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" + ret += " * several version checks together.\n" + ret += " */\n" + ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" # strings ################################################################## ret += f""" /* strings ********************************************************************/ -static const char *version_strs[{c_ver_enum('NUM')}] = {{ +const char *_{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: @@ -710,122 +701,115 @@ static const char *version_strs[{c_ver_enum('NUM')}] = {{ ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' ret += ifdef_pop(0) ret += "};\n" + + ret += "\n" + ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" + ret += msg_table("msg", "name", "char *", (0, 0x100, 1)) + + # bitmasks ################################################################# ret += f""" -const char *{idprefix}version_str(enum {idprefix}version ver) {{ - assert(0 <= ver && ver < {c_ver_enum('NUM')}); - return version_strs[ver]; -}} +/* bitmasks *******************************************************************/ """ + for typ in typs: + if not isinstance(typ, idl.Bitfield): + continue + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + ret += ifdef_push(2, c_ver_ifdef({ver})) + ret += ( + f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + "1" if typ.bit_is_valid(bitname, ver) else "0" + for bitname in reversed(typ.bits) + ) + + ",\n" + ) + ret += ifdef_pop(1) + ret += "};\n" + ret += ifdef_pop(0) # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ -static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { - if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) - /* If needed-net-size overflowed uint32_t, then - * there's no way that actual-net-size will live up to - * that. */ - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - if (ctx->net_offset > ctx->net_size) - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - return false; +LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { +\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) +\t\t/* If needed-net-size overflowed uint32_t, then +\t\t * there's no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\tif (ctx->net_offset > ctx->net_size) +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; } -static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { - if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) - /* If needed-host-size overflowed size_t, then there's - * no way that actual-net-size will live up to - * that. */ - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - return false; +LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { +\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) +\t\t/* If needed-host-size overflowed size_t, then there's +\t\t * no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; } -static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, +LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { - for (size_t i = 0; i < cnt; i++) - if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) - return true; - return false; +\tfor (size_t i = 0; i < cnt; i++) +\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) +\t\t\treturn true; +\treturn false; } -#define validate_1(ctx) _validate_size_net(ctx, 1) -#define validate_2(ctx) _validate_size_net(ctx, 2) -#define validate_4(ctx) _validate_size_net(ctx, 4) -#define validate_8(ctx) _validate_size_net(ctx, 8) +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ - for typ in typs: - inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + for typ in topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - - if isinstance(typ, Bitfield): - ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += ifdef_push(2, c_ver_ifdef({ver})) - ret += ( - f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - "1" if typ.bit_is_valid(bitname, ver) else "0" - for bitname in reversed(typ.bits) - ) - + ",\n" - ) - ret += ifdef_pop(1) - ret += "};\n" - - ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" - - if typ.name == "d": # SPECIAL - # Optimize... maybe the compiler could figure out to do - # this, but let's make it obvious. - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_4(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" - ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" - ret += "}\n" - continue - if typ.name == "s": # SPECIAL - # Add an extra nul-byte on the host, and validate UTF-8 - # (also, similar optimization to "d"). - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_2(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" - ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" - ret += "\t\treturn true;\n" - ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" - ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' - ret += "\treturn false;\n" - ret += "}\n" - continue + ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: - case Number(): + case idl.Number(): ret += f"\treturn validate_{typ.prim.name}(ctx);\n" - case Bitfield(): + case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" ret += ( f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" ) - ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + if typ.static_size == 1: + ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + else: + ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" - case Struct(): # and Message() + case idl.Struct(): # and idl.Message() if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue + def should_save_value(member: idl.StructMember) -> bool: + nonlocal typ + assert isinstance(typ, idl.Struct) + return bool( + member.max + or member.val + or any(m.cnt == member for m in typ.members) + ) + # Pass 1 - declare value variables for member in typ.members: - if member.max or member.val: + if should_save_value(member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ)} {member.name};\n" ret += ifdef_pop(1) @@ -834,50 +818,67 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, ExprSym) and tok.name.startswith("&"): + if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): if tok.name[1:] not in mark_offset: ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" mark_offset.add(tok.name[1:]) # Pass 3 - main pass ret += "\treturn false\n" - prev_size: int | None = None for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: - assert prev_size - ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if member.typ.static_size == 1: # SPECIAL (zerocopy) + ret += f"_validate_size_net(ctx, {member.cnt.name})" + else: + ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" + if typ.name == "s": # SPECIAL (string) + ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' else: - if member.max or member.val: + if should_save_value(member): ret += "(" if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" if member.name in mark_offset: ret += "; })" - if member.max or member.val: - bytes = member.static_size - assert bytes - bits = bytes * 8 - ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" + if should_save_value(member): + nbytes = member.static_size + assert nbytes + if nbytes == 1: + ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + else: + ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" - prev_size = member.static_size # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: + + def lookup_sym(sym: str) -> str: + match sym: + case "end": + return "ctx->net_offset" + case _: + assert sym.startswith("&") + return f"_{sym[1:]}_offset" + if member.max: + assert member.static_size + nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n' + ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.name}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n' if member.val: + assert member.static_size + nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n' + ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.name}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" @@ -888,38 +889,38 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, ret += """ /* unmarshal_* ****************************************************************/ -static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { - *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 1; +LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { +\t*out = ctx->net_bytes[ctx->net_offset]; +\tctx->net_offset += 1; } -static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { - *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 2; +LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { +\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 2; } -static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { - *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 4; +LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { +\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 4; } -static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { - *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 8; +LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { +\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 8; } """ - for typ in typs: - inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + for typ in topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static {inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" match typ: - case Number(): + case idl.Number(): ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" - case Bitfield(): + case idl.Bitfield(): ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" - case Struct(): + case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: @@ -937,12 +938,15 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) if member.in_versions != typ.in_versions: ret += "{\n" ret += prefix - ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" - ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" - if typ.name in ["d", "s"]: # SPECIAL - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" + if member.typ.static_size == 1: # SPECIAL (string, zerocopy) + ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += ( + f"{prefix}ctx->net_offset += out->{member.cnt.name};\n" + ) else: + ret += f"out->{member.name} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" @@ -950,9 +954,6 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) ret += ( f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ) - if typ.name == "s": # SPECIAL - ret += "\tctx->extra++;\n" - ret += "\tout->utf8[out->len] = '\\0';\n" ret += ifdef_pop(1) ret += "}\n" ret += ifdef_pop(0) @@ -961,174 +962,376 @@ static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) ret += """ /* marshal_* ******************************************************************/ -static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { - lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", - (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", - ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), - ctx->ctx->max_msg_size); - return true; -} - -static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { - if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) - return _marshal_too_large(ctx); - ctx->net_bytes[ctx->net_offset] = *val; - ctx->net_offset += 1; - return false; -} +""" + ret += c_macro( + "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" + "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" + "\t\tctx->net_iov_cnt++;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" + "\tctx->net_iov_cnt++;\n" + ) + ret += c_macro( + "#define MARSHAL_BYTES(ctx, data, len)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" + "\tctx->net_copied_size += len;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" + ) + ret += c_macro( + "#define MARSHAL_U8LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tctx->net_copied[ctx->net_copied_size] = val;\n" + "\tctx->net_copied_size += 1;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" + ) + ret += c_macro( + "#define MARSHAL_U16LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 2;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" + ) + ret += c_macro( + "#define MARSHAL_U32LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 4;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" + ) + ret += c_macro( + "#define MARSHAL_U64LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 8;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" + ) -static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { - if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) - return _marshal_too_large(ctx); - encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 2; - return false; -} + class OffsetExpr: + static: int + cond: dict[frozenset[str], "OffsetExpr"] + rep: list[tuple[Path, "OffsetExpr"]] + + def __init__(self) -> None: + self.static = 0 + self.rep = [] + self.cond = {} + + def add(self, other: "OffsetExpr") -> None: + self.static += other.static + self.rep += other.rep + for k, v in other.cond.items(): + if k in self.cond: + self.cond[k].add(v) + else: + self.cond[k] = v + + def gen_c( + self, + dsttyp: str, + dstvar: str, + root: str, + indent_depth: int, + loop_depth: int, + ) -> str: + oneline: list[str] = [] + multiline = "" + if self.static: + oneline.append(str(self.static)) + for cnt, sub in self.rep: + if not sub.cond and not sub.rep: + if sub.static == 1: + oneline.append(cnt.c_str(root)) + else: + oneline.append(f"({cnt.c_str(root)})*{sub.static}") + continue + loopvar = chr(ord("i") + loop_depth) + multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += sub.gen_c( + "", dstvar, root, indent_depth + 1, loop_depth + 1 + ) + multiline += f"{'\t'*indent_depth}}}\n" + for vers, sub in self.cond.items(): + multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" + multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) + multiline += f"{'\t'*indent_depth}}}\n" + multiline += ifdef_pop(indent_depth) + if dsttyp: + if not oneline: + oneline.append("0") + ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" + elif oneline: + ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" + ret += multiline + return ret + + type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] + + def get_offset_expr(typ: idl.Type, recurse: OffsetExprRecursion) -> OffsetExpr: + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret = OffsetExpr() + ret.static = typ.static_size + return ret + + stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]] + + def pop_root() -> None: + assert False + + def pop_cond() -> None: + nonlocal stack + key = frozenset(stack[-1][0].elems[-1].in_versions) + if key in stack[-2][1].cond: + stack[-2][1].cond[key].add(stack[-1][1]) + else: + stack[-2][1].cond[key] = stack[-1][1] + stack = stack[:-1] + + def pop_rep() -> None: + nonlocal stack + member_path = stack[-1][0] + member = member_path.elems[-1] + assert member.cnt + cnt_path = member_path.parent().add(member.cnt) + stack[-2][1].rep.append((cnt_path, stack[-1][1])) + stack = stack[:-1] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]: + nonlocal recurse + + ret = recurse(path) + if ret != WalkCmd.KEEP_GOING: + return ret, None + + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + stack[-1][2]() + + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + stack.append((path, OffsetExpr(), pop_cond)) + if child.cnt: + stack.append((path, OffsetExpr(), pop_rep)) + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + stack[-1][1].static += child.typ.static_size + return ret, pop + + stack = [(Path(typ), OffsetExpr(), pop_root)] + walk(typ, handle) + return stack[0][1] + + def go_to_end(path: Path) -> WalkCmd: + return WalkCmd.KEEP_GOING + + def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: + def ret(path: Path) -> WalkCmd: + if len(path.elems) == 1 and path.elems[0].name == name: + return WalkCmd.ABORT + return WalkCmd.KEEP_GOING -static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { - if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) - return true; - encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 4; - return false; -} + return ret -static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { - if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) - return true; - encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 8; - return false; -} -""" for typ in typs: - inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + if not ( + isinstance(typ, idl.Message) or typ.name == "stat" + ): # SPECIAL (include stat) + continue + assert isinstance(typ, idl.Struct) ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static {inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n" - match typ: - case Number(): - ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n" - case Bitfield(): - ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n" - ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n" - case Struct(): - if len(typ.members) == 0: - ret += "\treturn false;\n" - ret += "}\n" - continue + ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" - # Pass 1 - declare offset variables - mark_offset = set() - for member in typ.members: - if member.val: - if member.name not in mark_offset: - ret += f"\tuint32_t _{member.name}_offset;\n" - mark_offset.add(member.name) - for tok in member.val.tokens: - if isinstance(tok, ExprSym) and tok.name.startswith("&"): - if tok.name[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" - mark_offset.add(tok.name[1:]) + # Pass 1 - check size + max_size = max(typ.max_size(v) for v in typ.in_versions) - # Pass 2 - main pass - ret += "\treturn false\n" - for member in typ.members: - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += "\t || " - if member.in_versions != typ.in_versions: - ret += "( " + c_ver_cond(member.in_versions) + " && " - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - if member.cnt: - ret += "({ bool err = false;\n" - ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n" - ret += "\t \terr = " - if typ.name in ["d", "s"]: # SPECIAL - # Special-case is that we cast from `char` to `uint8_t`. - ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" + if max_size > u32max: # SPECIAL (9P2000.e) + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint64_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" + else: + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint32_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" + if isinstance(typ, idl.Message): # SPECIAL (disable for stat) + ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%PRIu32)",\n' + ret += f'\t\t\t"{typ.name}",\n' + ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' + ret += "\t\t\tctx->ctx->max_msg_size);\n" + ret += "\t\treturn true;\n" + ret += "\t}\n" + + # Pass 2 - write data + ifdef_depth = 1 + stack: list[tuple[Path, bool]] = [(Path(typ), False)] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + ret += f"{'\t'*(len(stack)-1)}}}\n" + if stack[-1][1]: + ifdef_depth -= 1 + ret += ifdef_pop(ifdef_depth) + stack = stack[:-1] + + loopdepth = sum(1 for elem in path.elems if elem.cnt) + struct = path.elems[-1].typ if path.elems else path.root + if isinstance(struct, idl.Struct): + offsets: list[str] = [] + for member in struct.members: + if not member.val: + continue + for tok in member.val.tokens: + if not isinstance(tok, idl.ExprSym): + continue + if tok.name == "end" or tok.name.startswith("&"): + if tok.name not in offsets: + offsets.append(tok.name) + for name in offsets: + name_prefix = "offsetof_" + "".join( + m.name + "_" for m in path.elems + ) + if name == "end": + if not path.elems: + nonlocal max_size + if max_size > u32max: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" + else: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" + continue + recurse: OffsetExprRecursion = go_to_end + else: + assert name.startswith("&") + name = name[1:] + recurse = go_to_tok(name) + expr = get_offset_expr(struct, recurse) + expr_prefix = path.c_str("val->", loopdepth) + if not expr_prefix.endswith(">"): + expr_prefix += "." + ret += expr.gen_c( + "uint32_t", + name_prefix + name, + expr_prefix, + len(stack), + loopdepth, + ) + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions)) + ifdef_depth += 1 + ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" + stack.append((path, True)) + if child.cnt: + cnt_path = path.parent().add(child.cnt) + if child.typ.static_size == 1: # SPECIAL (zerocopy) + if path.root.name == "stat": # SPECIAL (stat) + ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n" - ret += f"\t err; }})" - elif member.val: - # Just increment net_offset, don't actually marsha anything (yet). - assert member.static_size - ret += ( - f"({{ ctx->net_offset += {member.static_size}; false; }})" - ) + ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + return WalkCmd.KEEP_GOING, pop + loopvar = chr(ord("i") + loopdepth - 1) + ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + stack.append((path, False)) + if not isinstance(child.typ, idl.Struct): + if child.val: + + def lookup_sym(sym: str) -> str: + nonlocal path + if sym.startswith("&"): + sym = sym[1:] + return ( + "offsetof_" + + "".join(m.name + "_" for m in path.elems[:-1]) + + sym + ) + + val = c_expr(child.val, lookup_sym) else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" - if member.name in mark_offset: - ret += "; })" - if member.in_versions != typ.in_versions: - ret += " )" - ret += "\n" + val = path.c_str("val->") + if isinstance(child.typ, idl.Bitfield): + val += f" & {child.typ.name}_masks[ctx->ctx->version]" + ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" + return WalkCmd.KEEP_GOING, pop - # Pass 3 - marshal ,val= members - for member in typ.members: - if member.val: - assert member.static_size - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n" + walk(typ, handle) - ret += ifdef_pop(1) - ret += "\t ;\n" + ret += "\treturn false;\n" ret += "}\n" ret += ifdef_pop(0) - # tables / exports ######################################################### - ret += f""" -/* tables / exports ***********************************************************/ - -#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ - .name = #typ, \\ - .basesize = sizeof(struct {idprefix}msg_##typ), \\ - .validate = validate_##typ, \\ - .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ - .marshal = (_marshal_fn_t)marshal_##typ, \\ - }} -#define _NONMSG(num) [num] = {{ \\ - .name = #num, \\ - }} - -struct _table_version _{idprefix}versions[{c_ver_enum('NUM')}] = {{ + # function tables ########################################################## + ret += """ +/* function tables ************************************************************/ """ - id2typ: dict[int, Message] = {} - for msg in [msg for msg in typs if isinstance(msg, Message)]: - id2typ[msg.msgid] = msg - for ver in ["unknown", *sorted(versions)]: - if ver != "unknown": - ret += ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n" - - for n in range(0, 0x100): - xmsg: Message | None = id2typ.get(n, None) - if xmsg: - if ver == "unknown": # SPECIAL - if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: - xmsg = None - else: - if ver not in xmsg.in_versions: - xmsg = None - if xmsg: - ret += f"\t\t_MSG({xmsg.name}),\n" - else: - ret += "\t\t_NONMSG(0x{:02X}),\n".format(n) - ret += "\t}},\n" + ret += "\n" + ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" + rerror = next(typ for typ in typs if typ.name == "Rerror") + ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + for ver in sorted(versions): + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" ret += ifdef_pop(0) ret += "};\n" + ret += "\n" + ret += c_macro( + f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" + f"\t\t.validate = validate_##typ,\n" + f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" + f"\t}}\n" + ) + ret += c_macro( + f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" + f"\t}}\n" + ) + ret += "\n" + ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) + ret += "\n" + ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) + ret += f""" -FLATTEN bool _{idprefix}validate_stat(struct _validate_ctx *ctx) {{ - return validate_stat(ctx); +LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ +\treturn validate_stat(ctx); }} -FLATTEN void _{idprefix}unmarshal_stat(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{ - unmarshal_stat(ctx, out); +LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{ +\tunmarshal_stat(ctx, out); }} -FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{ - return marshal_stat(ctx, val); +LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{ +\treturn marshal_stat(ctx, val); }} """ @@ -1136,46 +1339,40 @@ FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat return ret -################################################################################ +# Main ######################################################################### -class Parser: - cache: dict[str, tuple[str, list[Type]]] = {} - - def parse_file(self, filename: str) -> tuple[str, list[Type]]: - filename = os.path.normpath(filename) - if filename not in self.cache: - - def get_include(other_filename: str) -> tuple[str, list[Type]]: - return self.parse_file(os.path.join(filename, "..", other_filename)) +if __name__ == "__main__": + import sys - self.cache[filename] = parse_file(filename, get_include) - return self.cache[filename] + if typing.TYPE_CHECKING: - def all(self) -> tuple[set[str], list[Type]]: - ret_versions: set[str] = set() - ret_typs: dict[str, Type] = {} - for version, typs in self.cache.values(): - if version in ret_versions: - raise ValueError(f"duplicate protocol version {repr(version)}") - ret_versions.add(version) - for typ in typs: - if typ.name in ret_typs: - if typ != ret_typs[typ.name]: - raise ValueError(f"duplicate type name {repr(typ.name)}") - else: - ret_typs[typ.name] = typ - return ret_versions, list(ret_typs.values()) + class ANSIColors: + MAGENTA = "\x1b[35m" + RED = "\x1b[31m" + RESET = "\x1b[0m" - -if __name__ == "__main__": - import sys + else: + from _colorize import ANSIColors # Present in Python 3.13+ if len(sys.argv) < 2: raise ValueError("requires at least 1 .9p filename") - parser = Parser() + parser = idl.Parser() for txtname in sys.argv[1:]: - parser.parse_file(txtname) + try: + parser.parse_file(txtname) + except SyntaxError as e: + print( + f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}", + file=sys.stderr, + ) + assert e.text + print(f"\t{e.text}", file=sys.stderr) + print( + f"\t{ANSIColors.RED}{'~'*len(e.text)}{ANSIColors.RESET}", + file=sys.stderr, + ) + sys.exit(2) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: |