diff options
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-x | lib9p/idl.gen | 1180 |
1 files changed, 1180 insertions, 0 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen new file mode 100755 index 0000000..1dafef9 --- /dev/null +++ b/lib9p/idl.gen @@ -0,0 +1,1180 @@ +#!/usr/bin/env python +# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files +# defining 9P protocol variants. +# +# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-Licence-Identifier: AGPL-3.0-or-later + +import enum +import os.path +import re +from abc import ABC, abstractmethod +from typing import Callable, Literal, TypeAlias, TypeVar + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + +T = TypeVar("T") + +# Types ######################################################################## + +Type: TypeAlias = "Primitive | Number | Bitfield | Struct | Message" + + +class Primitive(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 4 + u64 = 8 + + @property + def in_versions(self) -> set[str]: + return set() + + @property + def name(self) -> str: + return str(self.value) + + @property + def static_size(self) -> int: + return self.value + + +class Number: + name: str + in_versions: set[str] + + prim: Primitive + + def __init__(self) -> None: + self.in_versions = set() + + @property + def static_size(self) -> int: + return self.prim.static_size + + +class BitfieldVal: + name: str + in_versions: set[str] + + val: str + + def __init__(self) -> None: + self.in_versions = set() + + +class Bitfield: + name: str + in_versions: set[str] + + prim: Primitive + + bits: list[str] # bitnames + names: dict[str, BitfieldVal] # bits *and* aliases + + def __init__(self) -> None: + self.in_versions = set() + self.names = {} + + @property + def static_size(self) -> int: + return self.prim.static_size + + def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool: + """Return whether the given bit is valid in the given protocol + version. 
+ + """ + bitname = self.bits[bit] if isinstance(bit, int) else bit + assert bitname in self.bits + if not bitname: + return False + if bitname.startswith("_"): + return False + if ver and (ver not in self.names[bitname].in_versions): + return False + return True + + +class ExprLit: + val: int + + def __init__(self, val: int) -> None: + self.val = val + + +class ExprSym: + name: str + + def __init__(self, name: str) -> None: + self.name = name + + +class ExprOp: + op: Literal["-", "+"] + + def __init__(self, op: Literal["-", "+"]) -> None: + self.op = op + + +class Expr: + tokens: list[ExprLit | ExprSym | ExprOp] + + def __init__(self) -> None: + self.tokens = [] + + def __bool__(self) -> bool: + return len(self.tokens) > 0 + + +class StructMember: + # from left-to-right when parsing + cnt: str | None = None + name: str + typ: Type + max: Expr + val: Expr + + in_versions: set[str] + + @property + def static_size(self) -> int | None: + if self.cnt: + return None + return self.typ.static_size + + +class Struct: + name: str + in_versions: set[str] + + members: list[StructMember] + + def __init__(self) -> None: + self.in_versions = set() + + @property + def static_size(self) -> int | None: + size = 0 + for member in self.members: + msize = member.static_size + if msize is None: + return None + size += msize + return size + + +class Message(Struct): + @property + def msgid(self) -> int: + assert len(self.members) >= 3 + assert self.members[1].name == "typ" + assert self.members[1].static_size == 1 + assert self.members[1].val + assert len(self.members[1].val.tokens) == 1 + assert isinstance(self.members[1].val.tokens[0], ExprLit) + return self.members[1].val.tokens[0].val + + +# Parse *.9p ################################################################### + +re_priname = "(?:1|2|4|8)" # primitive names +re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names +re_impname = r"(?:\*|" + re_symname + ")" # names we can import +re_msgname = 
r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be + +re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be + +re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)" + +re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})" +re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)" + +re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?" + + +def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: + spec = spec.strip() + + bit: int | None + val: BitfieldVal + if m := re.fullmatch(re_bitspec_bit, spec): + bit = int(m.group("bit")) + name = m.group("name") + + val = BitfieldVal() + val.name = name + val.val = f"1<<{bit}" + val.in_versions.add(ver) + + if bit < 0 or bit >= len(bf.bits): + raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") + if bf.bits[bit]: + raise ValueError(f"{bf.name}: bit {bit} already assigned") + bf.bits[bit] = val.name + elif m := re.fullmatch(re_bitspec_alias, spec): + name = m.group("name") + valstr = m.group("val") + + val = BitfieldVal() + val.name = name + val.val = valstr + val.in_versions.add(ver) + else: + raise SyntaxError(f"invalid bitfield spec {repr(spec)}") + + if val.name in bf.names: + raise ValueError(f"{bf.name}: name {val.name} already assigned") + bf.names[val.name] = val + + +def parse_expr(expr: str) -> Expr: + assert re.fullmatch(re_expr, expr) + ret = Expr() + for tok in re.split("([-+])", expr): + if tok == "-" or tok == "+": + ret.tokens += [ExprOp(tok)] + elif re.fullmatch("[0-9]+", tok): + ret.tokens += [ExprLit(int(tok))] + else: + ret.tokens += [ExprSym(tok)] + return ret + + +def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: + for spec in specs.split(): + m = re.fullmatch(re_memberspec, spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + + member = StructMember() + member.in_versions = {ver} + 
+ member.name = m.group("name") + if any(x.name == member.name for x in struct.members): + raise ValueError(f"duplicate member name {repr(member.name)}") + + if m.group("typ") not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + member.typ = env[m.group("typ")] + + if cnt := m.group("cnt"): + if len(struct.members) == 0 or struct.members[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(struct.members[-1].typ, Primitive): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + member.cnt = cnt + + if maxstr := m.group("max"): + if (not isinstance(member.typ, Primitive)) or member.cnt: + raise ValueError("',max=' may only be specified on a non-repeated atom") + member.max = parse_expr(maxstr) + else: + member.max = Expr() + + if valstr := m.group("val"): + if (not isinstance(member.typ, Primitive)) or member.cnt: + raise ValueError("',val=' may only be specified on a non-repeated atom") + member.val = parse_expr(valstr) + else: + member.val = Expr() + + struct.members += [member] + + +def re_string(grpname: str) -> str: + return f'"(?P<{grpname}>[^"]*)"' + + +re_line_version = f"version\\s+{re_string('version')}" +re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)" +re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" +re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" +re_line_bitfield_ = f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}" +re_line_struct = ( + f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" +) +re_line_msg = ( + f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" +) +re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg + + +def parse_file( + filename: str, get_include: Callable[[str], tuple[str, list[Type]]] +) -> tuple[str, list[Type]]: + version: str | 
None = None + env: dict[str, Type] = { + "1": Primitive.u8, + "2": Primitive.u16, + "4": Primitive.u32, + "8": Primitive.u64, + } + + def get_type(name: str, tc: type[T]) -> T: + nonlocal env + if name not in env: + raise NameError(f"Unknown type {repr(name)}") + ret = env[name] + if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): + raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") + return ret + + with open(filename, "r") as fh: + prev: Type | None = None + for line in fh: + line = line.split("#", 1)[0].rstrip() + if not line: + continue + if m := re.fullmatch(re_line_version, line): + if version: + raise SyntaxError("must have exactly 1 version line") + version = m.group("version") + continue + if not version: + raise SyntaxError("must have exactly 1 version line") + + if m := re.fullmatch(re_line_import, line): + other_version, other_typs = get_include(m.group("file")) + for symname in m.group("syms").split(sep=","): + symname = symname.strip() + for typ in other_typs: + if typ.name == symname or symname == "*": + match typ: + case Primitive(): + pass + case Number(): + typ.in_versions.add(version) + case Bitfield(): + typ.in_versions.add(version) + for val in typ.names.values(): + if other_version in val.in_versions: + val.in_versions.add(version) + case Struct(): # and Message() + typ.in_versions.add(version) + for member in typ.members: + if other_version in member.in_versions: + member.in_versions.add(version) + env[typ.name] = typ + elif m := re.fullmatch(re_line_num, line): + num = Number() + num.name = m.group("name") + num.in_versions.add(version) + + prim = env[m.group("prim")] + assert isinstance(prim, Primitive) + num.prim = prim + + env[num.name] = num + prev = num + elif m := re.fullmatch(re_line_bitfield, line): + bf = Bitfield() + bf.name = m.group("name") + + prim = env[m.group("prim")] + assert isinstance(prim, Primitive) + bf.prim = prim + + bf.bits = (prim.static_size * 8) * [""] + + env[bf.name] = bf + 
prev = bf + elif m := re.fullmatch(re_line_bitfield_, line): + bf = get_type(m.group("name"), Bitfield) + parse_bitspec(version, bf, m.group("member")) + + prev = bf + elif m := re.fullmatch(re_line_struct, line): + match m.group("op"): + case "=": + struct = Struct() + struct.name = m.group("name") + struct.in_versions.add(version) + struct.members = [] + parse_members(version, env, struct, m.group("members")) + + env[struct.name] = struct + prev = struct + case "+=": + struct = get_type(m.group("name"), Struct) + parse_members(version, env, struct, m.group("members")) + + prev = struct + elif m := re.fullmatch(re_line_msg, line): + match m.group("op"): + case "=": + msg = Message() + msg.name = m.group("name") + msg.in_versions.add(version) + msg.members = [] + parse_members(version, env, msg, m.group("members")) + + env[msg.name] = msg + prev = msg + case "+=": + msg = get_type(m.group("name"), Message) + parse_members(version, env, msg, m.group("members")) + + prev = msg + elif m := re.fullmatch(re_line_cont, line): + match prev: + case Bitfield(): + parse_bitspec(version, prev, m.group("specs")) + case Struct(): # and Message() + parse_members(version, env, prev, m.group("specs")) + case _: + raise SyntaxError( + "continuation line must come after a bitfield, struct, or msg line" + ) + else: + raise SyntaxError(f"invalid line {repr(line)}") + if not version: + raise SyntaxError("must have exactly 1 version line") + + typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] + + for typ in [typ for typ in typs if isinstance(typ, Struct)]: + valid_syms = ["end", *["&" + m.name for m in typ.members]] + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, ExprSym) and tok.name not in valid_syms: + raise ValueError( + f"{typ.name}.{member.name}: invalid sym: {tok.name}" + ) + + return version, typs + + +# Generate C ################################################################### + + +def 
c_ver_enum(idprefix: str, ver: str) -> str: + return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" + + +def c_ver_ifdef(idprefix: str, versions: set[str]) -> str: + return " || ".join( + f"defined(CONFIG_{idprefix.upper()}ENABLE_{c_ver_enum('', v)})" + for v in sorted(versions) + ) + + +def c_ver_cond(idprefix: str, versions: set[str]) -> str: + if len(versions) == 1: + return f"(ctx->ctx->version=={c_ver_enum(idprefix, next(v for v in versions))})" + return ( + "( " + (" || ".join(c_ver_cond(idprefix, {v}) for v in sorted(versions))) + " )" + ) + + +def c_typename(idprefix: str, typ: Type) -> str: + match typ: + case Primitive(): + return f"uint{typ.value*8}_t" + case Number(): + return f"{idprefix}{typ.name}_t" + case Bitfield(): + return f"{idprefix}{typ.name}_t" + case Message(): + return f"struct {idprefix}msg_{typ.name}" + case Struct(): + return f"struct {idprefix}{typ.name}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def gen_h(idprefix: str, versions: set[str], typs: list[Type]) -> str: + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! 
*/ + +#ifndef _LIB9P_9P_H_ +# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead +#endif + +#include <stdint.h> /* for uint{{n}}_t types */ +""" + + _ifdef: list[str] = [] + + def push_ifdef(v: str) -> None: + nonlocal _ifdef + nonlocal ret + ret += f"#if {v}\n" + _ifdef += [v] + + def pop_ifdef(n: int) -> None: + nonlocal _ifdef + nonlocal ret + while len(_ifdef) > n: + ret += f"#endif /* {_ifdef[-1]}\n" + _ifdef = _ifdef[:-1] + + def set_ifdef(v: str) -> None: + nonlocal _ifdef + nonlocal ret + if v != _ifdef[-1]: + ret += f"#elif {v}\n" + _ifdef[-1] = v + + def pushorset_ifdef(n: int, v: str) -> None: + nonlocal _ifdef + nonlocal ret + if len(_ifdef) < n: + push_ifdef(v) + else: + set_ifdef(v) + + ret += f""" +/* versions *******************************************************************/ + +enum {idprefix}version {{ +""" + fullversions = ["unknown = 0", *sorted(versions)] + verwidth = max(len(v) for v in fullversions) + for ver in fullversions: + if ver in versions: + pushorset_ifdef(1, c_ver_ifdef(idprefix, {ver})) + ret += f"\t{c_ver_enum(idprefix, ver)}," + ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' + pop_ifdef(0) + ret += f"\t{c_ver_enum(idprefix, 'NUM')},\n" + ret += "};\n" + ret += "\n" + ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" + + ret += """ +/* non-message types **********************************************************/ +""" + for typ in [typ for typ in typs if not isinstance(typ, Message)]: + ret += "\n" + pushorset_ifdef(1, c_ver_ifdef(idprefix, typ.in_versions)) + match typ: + case Number(): + ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" + case Bitfield(): + ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n" + names = [ + *reversed( + [typ.bits[n] or f"_UNUSED_{n}" for n in range(0, len(typ.bits))] + ), + "", + *[k for k in typ.names if k not in typ.bits], + ] + namewidth = max(len(name) 
for name in names) + + ret += "\n" + for name in names: + if name == "": + ret += "\n" + continue + pushorset_ifdef( + 2, c_ver_ifdef(idprefix, typ.names[name].in_versions) + ) + if name.startswith("_"): + c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}" + else: + c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}" + if name in typ.names: + val = typ.names[name].val + else: + assert name.startswith("_UNUSED_") + val = f"1<<{name[len('_UNUSED_'):]}" + ret += f"#define {c_name}{' '*(namewidth-len(name))} (({c_typename(idprefix, typ)})({val}))\n" + pop_ifdef(1) + case Struct(): + typewidth = max(len(c_typename(idprefix, m.typ)) for m in typ.members) + + ret += c_typename(idprefix, typ) + " {\n" + for member in typ.members: + if member.val: + continue + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + c_type = c_typename(idprefix, member.typ) + if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL + c_type = "char" + ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + pop_ifdef(1) + ret += "};\n" + pop_ifdef(0) + + ret += """ +/* messages *******************************************************************/ + +""" + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message)) + for msg in [msg for msg in typs if isinstance(msg, Message)]: + pushorset_ifdef(1, c_ver_ifdef(idprefix, msg.in_versions)) + ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" + pop_ifdef(0) + ret += "};\n" + ret += "\n" + ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" + + for msg in [msg for msg in typs if isinstance(msg, Message)]: + ret += "\n" + pushorset_ifdef(1, c_ver_ifdef(idprefix, msg.in_versions)) + ret += c_typename(idprefix, msg) + " {" + if not msg.members: + ret += "};\n" + continue + ret += "\n" + + typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) + + for member in 
msg.members: + if member.val: + continue + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + pop_ifdef(1) + ret += "};\n" + pop_ifdef(0) + + return ret + + +def c_expr(expr: Expr) -> str: + ret: list[str] = [] + for tok in expr.tokens: + match tok: + case ExprOp(): + ret += [tok.op] + case ExprLit(): + ret += [str(tok.val)] + case ExprSym(name="end"): + ret += ["ctx->net_offset"] + case ExprSym(): + ret += [f"_{tok.name[1:]}_offset"] + return " ".join(ret) + + +def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str: + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#include <assert.h> +#include <stdbool.h> +#include <stddef.h> /* for size_t */ +#include <inttypes.h> /* for PRI* macros */ +#include <string.h> /* for memset() */ + +#include <lib9p/9p.h> + +#include "internal.h" +""" + + _ifdef: list[str] = [] + + def push_ifdef(v: str) -> None: + nonlocal _ifdef + nonlocal ret + ret += f"#if {v}\n" + _ifdef += [v] + + def pop_ifdef(n: int) -> None: + nonlocal _ifdef + nonlocal ret + while len(_ifdef) > n: + ret += f"#endif /* {_ifdef[-1]}\n" + _ifdef = _ifdef[:-1] + + def set_ifdef(v: str) -> None: + nonlocal _ifdef + nonlocal ret + if v != _ifdef[-1]: + ret += f"#elif {v}\n" + _ifdef[-1] = v + + def pushorset_ifdef(n: int, v: str) -> None: + nonlocal _ifdef + nonlocal ret + if len(_ifdef) < n: + push_ifdef(v) + else: + set_ifdef(v) + + def used(arg: str) -> str: + return arg + + def unused(arg: str) -> str: + return f"UNUSED({arg})" + + # strings ################################################################## + ret += f""" +/* strings ********************************************************************/ + +static const char *version_strs[{c_ver_enum(idprefix, 'NUM')}] = {{ +""" + for ver in ["unknown", *sorted(versions)]: + pushorset_ifdef(1, c_ver_ifdef(idprefix, {ver})) + ret += 
f'\t[{c_ver_enum(idprefix, ver)}] = "{ver}",\n' + pop_ifdef(0) + ret += "};\n" + ret += f""" +const char *{idprefix}version_str(enum {idprefix}version ver) {{ + assert(0 <= ver && ver < {c_ver_enum(idprefix, 'NUM')}); + return version_strs[ver]; +}} + +static const char *msg_type_strs[0x100] = {{ +""" + id2name: dict[int, str] = {} + for msg in [msg for msg in typs if isinstance(msg, Message)]: + id2name[msg.msgid] = msg.name + for n in range(0, 0x100): + ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) + ret += "};\n" + ret += f""" +const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ + assert(0 <= typ && typ <= 0xFF); + return msg_type_strs[typ]; +}} +""" + + # validate_* ############################################################### + ret += """ +/* validate_* *****************************************************************/ + +static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { + if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) + /* If needed-net-size overflowed uint32_t, then + * there's no way that actual-net-size will live up to + * that. */ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + if (ctx->net_offset > ctx->net_size) + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { + if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) + /* If needed-host-size overflowed size_t, then there's + * no way that actual-net-size will live up to + * that. 
*/ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, + size_t cnt, + _validate_fn_t item_fn, size_t item_host_size) { + for (size_t i = 0; i < cnt; i++) + if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) + return true; + return false; +} + +#define validate_1(ctx) _validate_size_net(ctx, 1) +#define validate_2(ctx) _validate_size_net(ctx, 2) +#define validate_4(ctx) _validate_size_net(ctx, 4) +#define validate_8(ctx) _validate_size_net(ctx, 8) +""" + for typ in typs: + inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE" + argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + ret += "\n" + pushorset_ifdef(1, c_ver_ifdef(idprefix, typ.in_versions)) + ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" + + if typ.name == "d": # SPECIAL + # Optimize... maybe the compiler could figure out to do + # this, but let's make it obvious. + ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (validate_4(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" + ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" + ret += "}\n" + continue + if typ.name == "s": # SPECIAL + # Add an extra nul-byte on the host, and validate UTF-8 + # (also, similar optimization to "d"). 
+ ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (validate_2(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" + ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" + ret += "\t\treturn true;\n" + ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" + ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' + ret += "\treturn false;\n" + ret += "}\n" + continue + + match typ: + case Number(): + ret += f"\treturn validate_{typ.prim.name}(ctx);\n" + case Bitfield(): + ret += f"\t if (validate_{typ.static_size}(ctx))\n" + ret += "\t\treturn true;\n" + ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_ver_enum(idprefix, 'NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + pushorset_ifdef(2, c_ver_ifdef(idprefix, {ver})) + ret += ( + f"\t\t[{c_ver_enum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + "1" if typ.bit_is_valid(bitname, ver) else "0" + for bitname in reversed(typ.bits) + ) + + ",\n" + ) + pop_ifdef(1) + ret += "\t};\n" + ret += ( + f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" + ) + ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += f"\tif (val & ~mask)\n" + ret += "\t\treturn lib9p_errorf(ctx->ctx,\n" + ret += f'\t\t LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' + ret += "\t\t val & ~mask);\n" + ret += "\treturn false;\n" + case Struct(): # and Message() + if len(typ.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + + # Pass 1 + for member in typ.members: + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + if member.max or member.val: + ret += f"\t{c_typename(idprefix, member.typ)} {member.name};\n" + 
pop_ifdef(1) + + # Pass 2 + mark_offset: set[str] = set() + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, ExprSym) and tok.name.startswith("&"): + if tok.name[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" + mark_offset.add(tok.name[1:]) + + # Pass 3 + ret += "\treturn false\n" + prev_size: int | None = None + for member in typ.members: + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + ret += f"\n\t|| " + if member.in_versions != typ.in_versions: + ret += "( " + c_ver_cond(idprefix, member.in_versions) + " && " + if member.cnt is not None: + assert prev_size + ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" + else: + if member.max or member.val: + ret += "(" + if member.name in mark_offset: + ret += f"({{ _{member.name}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.name}(ctx)" + if member.name in mark_offset: + ret += "; })" + if member.max or member.val: + bytes = member.static_size + assert bytes + bits = bytes * 8 + ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" + if member.in_versions != typ.in_versions: + ret += " )" + prev_size = member.static_size + + # Pass 4 + for member in typ.members: + if member.max: + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + ret += f"\n\t|| ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n" + ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n' + if member.val: + pushorset_ifdef(2, c_ver_ifdef(idprefix, member.in_versions)) + ret += f"\n\t|| ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n" + ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong 
(actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n' + + pop_ifdef(1) + ret += "\t;\n" + ret += "}\n" + pop_ifdef(0) + + # # unmarshal_* ############################################################## + # ret += """ + # /* unmarshal_* ****************************************************************/ + + # static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { + # *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 1; + # } + + # static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { + # *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 2; + # } + + # static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { + # *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 4; + # } + + # static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { + # *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 8; + # } + # """ + # for typ in typs: + # inline = ( + # " FLATTEN" + # if (isinstance(typ, Struct) and typ.msgid is not None) + # else " ALWAYS_INLINE" + # ) + # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + # ret += "\n" + # ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" + # match typ: + # case Bitfield(): + # ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" + # case Struct(): + # ret += "\tmemset(out, 0, sizeof(*out));\n" + + # if typ.members: + # struct_versions = typ.members[0].ver + # for member in typ.members: + # if member.valexpr: + # ret += f"\tctx->net_offset += {member.static_size};\n" + # continue + # ret += "\t" + # prefix = "\t" + # if member.ver != struct_versions: + # ret += "if ( " + c_ver_cond(idprefix, member.ver) + " ) " + # prefix = "\t\t" + # if member.cnt: + # if member.ver != 
struct_versions: + # ret += f"{{\n{prefix}" + # ret += f"out->{member.name} = ctx->extra;\n" + # ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" + # ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + # if typ.name in ["d", "s"]: # SPECIAL + # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" + # else: + # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + # if member.ver != struct_versions: + # ret += "\t}\n" + # else: + # ret += ( + # f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + # ) + # if typ.name == "s": # SPECIAL + # ret += "\tctx->extra++;\n" + # ret += "\tout->utf8[out->len] = '\\0';\n" + # ret += "}\n" + + # # marshal_* ################################################################ + # ret += """ + # /* marshal_* ******************************************************************/ + + # static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { + # lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", + # (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", + # ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? 
"client" : "server"), + # ctx->ctx->max_msg_size); + # return true; + # } + + # static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { + # if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) + # return _marshal_too_large(ctx); + # ctx->net_bytes[ctx->net_offset] = *val; + # ctx->net_offset += 1; + # return false; + # } + + # static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { + # if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) + # return _marshal_too_large(ctx); + # encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 2; + # return false; + # } + + # static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { + # if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) + # return true; + # encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 4; + # return false; + # } + + # static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { + # if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) + # return true; + # encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); + # ctx->net_offset += 8; + # return false; + # } + # """ + # for typ in typs: + # inline = ( + # " FLATTEN" + # if (isinstance(typ, Struct) and typ.msgid is not None) + # else " ALWAYS_INLINE" + # ) + # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used + # ret += "\n" + # ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" + # match typ: + # case Bitfield(): + # ret += "\n" + # ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" + # case Struct(): + # if len(typ.members) == 0: + # ret += "\n\treturn false;\n" + # ret += "}\n" + # continue + + # mark_offset = set() + # for member in typ.members: + # if member.valexpr: + # if member.name not in mark_offset: + # ret += f"\n\tuint32_t _{member.name}_offset;" + # mark_offset.add(member.name) + # for 
tok in member.valexpr: + # if ( + # isinstance(tok, ExprVal) + # and tok.name.startswith("&") + # and tok.name[1:] not in mark_offset + # ): + # ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" + # mark_offset.add(tok.name[1:]) + + # prefix0 = "\treturn " + # prefix1 = "\t || " + # prefix2 = "\t " + + # struct_versions = typ.members[0].ver + # prefix = prefix0 + # for member in typ.members: + # ret += f"\n{prefix}" + # if member.ver != struct_versions: + # ret += "( " + c_ver_cond(idprefix, member.ver) + " && " + # if member.name in mark_offset: + # ret += f"({{ _{member.name}_offset = ctx->net_offset; " + # if member.cnt: + # ret += "({" + # ret += f"\n{prefix2}\tbool err = false;" + # ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" + # if typ.name in ["d", "s"]: # SPECIAL + # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);" + # else: + # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" + # ret += f"\n{prefix2}\terr;" + # ret += f"\n{prefix2}}})" + # elif member.valexpr: + # assert member.static_size + # ret += ( + # f"({{ ctx->net_offset += {member.static_size}; false; }})" + # ) + # else: + # ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" + # if member.name in mark_offset: + # ret += "; })" + # if member.ver != struct_versions: + # ret += " )" + # prefix = prefix1 + + # for member in typ.members: + # if member.valexpr: + # assert member.static_size + # ret += f"\n{prefix}" + # ret += f"({{ encode_u{member.static_size*8}le(" + # for tok in member.valexpr: + # match tok: + # case ExprOp(): + # ret += f" {tok.op}" + # case ExprVal(name="end"): + # ret += " ctx->net_offset" + # case ExprVal(): + # ret += f" _{tok.name[1:]}_offset" + # ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" + + # ret += ";\n" + # ret += "}\n" + + # # vtables ################################################################## + # 
class Parser:
    """Parse ``.9p`` files, remembering the result for each file path.

    Each instance owns its own cache, so files parsed through one
    ``Parser`` are never reported by another instance's ``all()``.
    """

    # Maps a normalized filename to the (version, types) tuple that the
    # module-level parse_file() returned for that file.
    cache: "dict[str, tuple[str, list[Type]]]"

    def __init__(self) -> None:
        # Created per instance: a dict assigned in the class body would
        # be one shared object mutated by every Parser.
        self.cache = {}

    def parse_file(self, filename: str) -> "tuple[str, list[Type]]":
        """Parse `filename` (at most once) and return its cached result.

        The path is normalized first, so different spellings of the
        same path share a single cache entry.  Files that the parsed
        file includes are looked up relative to the directory that
        contains `filename`, and are parsed (and cached) through this
        same method.
        """
        filename = os.path.normpath(filename)
        if filename not in self.cache:
            basedir = os.path.dirname(filename)

            def get_include(other_filename: str) -> "tuple[str, list[Type]]":
                return self.parse_file(os.path.join(basedir, other_filename))

            # Delegates to the module-level parse_file() function, not
            # to this method.
            self.cache[filename] = parse_file(filename, get_include)
        return self.cache[filename]

    def all(self) -> "tuple[set[str], list[Type]]":
        """Merge everything parsed so far.

        Returns a tuple of (the set of protocol versions, the list of
        distinct types in first-seen order).

        Raises ValueError if two cached files declare the same protocol
        version, or if two types that do not compare equal share a
        name.
        """
        ret_versions: set[str] = set()
        ret_typs: dict[str, Type] = {}
        for version, typs in self.cache.values():
            if version in ret_versions:
                raise ValueError(f"duplicate protocol version {repr(version)}")
            ret_versions.add(version)
            for typ in typs:
                # The first definition of a name wins; later ones must
                # compare equal to it.
                known = ret_typs.setdefault(typ.name, typ)
                if typ != known:
                    raise ValueError(f"duplicate type name {repr(typ.name)}")
        return ret_versions, list(ret_typs.values())


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        raise ValueError("requires at least 1 .9p filename")
    parser = Parser()
    for txtname in sys.argv[1:]:
        parser.parse_file(txtname)
    versions, typs = parser.all()
    # Generated files are written next to this script (the directory
    # part of argv[0]), not into the current working directory.
    outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
    with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
        fh.write(gen_h("lib9p_", versions, typs))
    with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
        fh.write(gen_c("lib9p_", versions, typs))