diff options
Diffstat (limited to 'lib9p/9p.gen')
-rwxr-xr-x | lib9p/9p.gen | 1059 |
1 files changed, 0 insertions, 1059 deletions
diff --git a/lib9p/9p.gen b/lib9p/9p.gen deleted file mode 100755 index 816ec0a..0000000 --- a/lib9p/9p.gen +++ /dev/null @@ -1,1059 +0,0 @@ -#!/usr/bin/env python -# lib9p/9p.gen - Generate C marshalers/unmarshalers for .txt files -# defining 9P protocol variants. -# -# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> -# SPDX-Licence-Identifier: AGPL-3.0-or-later - -import enum -import os.path -import re -from typing import Callable, Literal, Sequence - -# This strives to be "general-purpose" in that it just acts on the -# *.txt inputs; but (unfortunately?) there are a few special-cases in -# this script, marked with "SPECIAL". - -# Parse *.txt ################################################################## - - -class Atom(enum.Enum): - u8 = 1 - u16 = 2 - u32 = 4 - u64 = 8 - - @property - def name(self) -> str: - return str(self.value) - - @property - def static_size(self) -> int: - return self.value - - -class BitfieldVal: - name: str - val: str - ver: set[str] - - def __init__(self) -> None: - self.ver = set() - - -class Bitfield: - name: str - bits: list[str] - names: dict[str, BitfieldVal] - - @property - def static_size(self) -> int: - return int((len(self.bits) + 7) / 8) - - def bitname_is_valid(self, bitname: str, ver: str | None = None) -> bool: - assert bitname in self.bits - if not bitname: - return False - if bitname.startswith("_"): - return False - if ver and (ver not in self.names[bitname].ver): - return False - return True - - -# `msgid/structname = "member1 member2..."` -# `structname = "member1 member2..."` -# `structname += "member1 member2..."` -class Struct: - msgid: int | None = None - msgver: set[str] - name: str - members: list["Member"] - - def __init__(self) -> None: - self.msgver = set() - - @property - def static_size(self) -> int | None: - size = 0 - for member in self.members: - msize = member.static_size - if msize is None: - return None - size += msize - return size - - -class ExprVal: - name: str - - def __init__(self, name: str) -> None: - self.name = name - - -class ExprOp: - op: Literal["-", "+"] - - def __init__(self, op: Literal["-", "+"]) -> None: - self.op = op - - -# `cnt*(name[typ])` -# the `cnt*(...)` wrapper is optional -class Member: - cnt: str | None = None - name: str - typ: Atom | Bitfield | Struct - max: int | None = None - valexpr: list[ExprVal | ExprOp] = [] - ver: set[str] - - @property - def static_size(self) -> int | None: - if self.cnt: - return None - return self.typ.static_size - - -def parse_valexpr(valexpr: str) -> list[ExprVal | ExprOp]: - ret: list[ExprVal | ExprOp] = [] - for tok in re.split("([-+])", valexpr): - match tok: - case "-": - ret += [ExprOp(tok)] - case "+": - ret += [ExprOp(tok)] - case _: - ret += [ExprVal(tok)] - return ret - - -re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" -re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+)|,val=(?P<val>[-+&a-zA-Z0-9_]+))*\\]\\)?" - - -def parse_members( - ver: str, - env: dict[str, Atom | Bitfield | Struct], - existing: list[Member], - specs: str, -) -> list[Member]: - ret = existing - for spec in specs.split(): - m = re.fullmatch(re_memberspec, spec) - if not m: - raise SyntaxError(f"invalid member spec {repr(spec)}") - - member = Member() - member.ver = {ver} - - member.name = m.group("name") - if any(x.name == member.name for x in ret): - raise ValueError(f"duplicate member name {repr(member.name)}") - - if m.group("typ") not in env: - raise NameError(f"Unknown type {repr(m.group(2))}") - member.typ = env[m.group("typ")] - - if cnt := m.group("cnt"): - if len(ret) == 0 or ret[-1].name != cnt: - raise ValueError(f"list count must be previous item: {repr(cnt)}") - if not isinstance(ret[-1].typ, Atom): - raise ValueError(f"list count must be an integer type: {repr(cnt)}") - member.cnt = cnt - - if maxstr := m.group("max"): - if (not isinstance(member.typ, Atom)) or member.cnt: - raise ValueError("',max=' may only be specified on a non-repeated atom") - member.max = int(maxstr) - - if valstr := m.group("val"): - if (not isinstance(member.typ, Atom)) or member.cnt: - raise ValueError("',val=' may only be specified on a non-repeated atom") - member.valexpr = parse_valexpr(valstr) - - ret += [member] - return ret - - -re_version = r'version\s+"(?P<version>[^"]+)"' -re_import = r"from\s+(?P<file>\S+)\s+import\s+(?P<syms>\S+(?:\s*,\s*\S+)*)\s*" -re_structspec = ( - r'(?:(?P<msgid>[0-9]+)/)?(?P<name>\S+)\s*(?P<op>\+?=)\s*"(?P<members>[^"]*)"' -) -re_structspec_cont = r'\s+"(?P<members>[^"]*)"' -re_bitfieldspec = r"bitfield\s+(?P<name>\S+)\s+(?P<size>[0-9]+)" -re_bitfieldspec_bit = r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<bit>[0-9]+)/(?P<name>\S+)" -re_bitfieldspec_alias = ( - r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<name>\S+)\s*=\s*(?P<val>.*)" -) - - -def parse_file( - filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]] -) -> tuple[str, list[Bitfield | Struct]]: - version: str | None = None - env: dict[str, Atom | Bitfield | Struct] = { - "1": Atom.u8, - "2": Atom.u16, - "4": Atom.u32, - "8": Atom.u64, - } - with open(filename, "r") as fh: - prev: Struct | Bitfield | None = None - for line in fh: - line = line.split("#", 1)[0].rstrip() - if not line: - continue - if m := re.fullmatch(re_version, line): - if version: - raise SyntaxError("must have exactly 1 version line") - version = m.group("version") - elif m := re.fullmatch(re_import, line): - if not version: - raise SyntaxError("must have exactly 1 version line") - other_version, other_typs = get_include(m.group("file")) - for symname in m.group("syms").split(sep=","): - symname = symname.strip() - for typ in other_typs: - if typ.name == symname or symname == "*": - match typ: - case Bitfield(): - for val in typ.names.values(): - if other_version in val.ver: - val.ver.add(version) - case Struct(): - if typ.msgid: - typ.msgver.add(version) - for member in typ.members: - if other_version in member.ver: - member.ver.add(version) - env[typ.name] = typ - elif m := re.fullmatch(re_structspec, line): - if not version: - raise SyntaxError("must have exactly 1 version line") - if m.group("op") == "+=" and m.group("msgid"): - raise SyntaxError("cannot += to a message that is not yet defined") - match m.group("op"): - case "=": - struct = Struct() - if m.group("msgid"): - struct.msgid = int(m.group("msgid")) - struct.msgver.add(version) - struct.name = m.group("name") - struct.members = parse_members( - version, env, [], m.group("members") - ) - env[struct.name] = struct - prev = struct - case "+=": - if m.group("name") not in env: - raise NameError(f"Unknown type {repr(m.group('name'))}") - _struct = env[m.group("name")] - if not isinstance(_struct, Struct): - raise NameError( - f"Type {repr(_struct.name)} is not a struct" - ) - struct = _struct - struct.members = parse_members( - version, env, struct.members, m.group("members") - ) - prev = struct - elif m := re.fullmatch(re_structspec_cont, line): - if not isinstance(prev, Struct): - raise SyntaxError( - "struct-continuation line must come after a struct line" - ) - assert version - prev.members = parse_members( - version, env, prev.members, m.group("members") - ) - elif m := re.fullmatch(re_bitfieldspec, line): - if not version: - raise SyntaxError("must have exactly 1 version line") - bf = Bitfield() - bf.name = m.group("name") - bf.bits = int(m.group("size")) * [""] - bf.names = {} - if len(bf.bits) not in [8, 16, 32, 64]: - raise ValueError(f"Bitfield {repr(bf.name)} has an unusual size") - env[bf.name] = bf - prev = bf - elif m := re.fullmatch(re_bitfieldspec_bit, line): - if m.group("bitfield"): - if m.group("bitfield") not in env: - raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") - _bf = env[m.group("bitfield")] - if not isinstance(_bf, Bitfield): - raise NameError(f"Type {repr(_bf.name)} is not a bitfield") - bf = _bf - prev = bf - else: - if not isinstance(prev, Bitfield): - raise SyntaxError( - "bitfield-continuation line must come after a bitfield line" - ) - bf = prev - bit = int(m.group("bit")) - name = m.group("name") - if bit < 0 or bit >= len(bf.bits): - raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") - if bf.bits[bit]: - raise ValueError(f"{bf.name}: bit {bit} already assigned") - if name in bf.names: - raise ValueError(f"{bf.name}: name {name} already assigned") - - bf.bits[bit] = name - - assert version - val = BitfieldVal() - val.name = name - val.val = f"1<<{bit}" - val.ver.add(version) - bf.names[name] = val - elif m := re.fullmatch(re_bitfieldspec_alias, line): - if m.group("bitfield"): - if m.group("bitfield") not in env: - raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}") - _bf = env[m.group("bitfield")] - if not isinstance(_bf, Bitfield): - raise NameError(f"Type {repr(_bf.name)} is not a bitfield") - bf = _bf - prev = bf - else: - if not isinstance(prev, Bitfield): - raise SyntaxError( - "bitfield-continuation line must come after a bitfield line" - ) - bf = prev - name = m.group("name") - valstr = m.group("val") - if name in bf.names: - raise ValueError(f"{bf.name}: name {name} already assigned") - - assert version - val = BitfieldVal() - val.name = name - val.val = valstr - val.ver.add(version) - bf.names[name] = val - else: - raise SyntaxError(f"invalid line {repr(line)}") - if not version: - raise SyntaxError("must have exactly 1 version line") - - typs = [x for x in env.values() if not isinstance(x, Atom)] - - for typ in just_structs_all(typs): - valid_vals = ["end", *["&" + m.name for m in typ.members]] - for member in typ.members: - for tok in member.valexpr: - if isinstance(tok, ExprVal) and tok.name not in valid_vals: - raise ValueError( - f"{typ.name}.{member.name}: invalid val: {tok.name}" - ) - - return version, typs - - -# Generate C ################################################################### - - -def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str: - match typ: - case Atom(): - return f"uint{typ.value*8}_t" - case Bitfield(): - return f"{idprefix}{typ.name}_t" - case Struct(): - if typ.msgid is not None: - return f"struct {idprefix}msg_{typ.name}" - return f"struct {idprefix}{typ.name}" - case _: - raise ValueError(f"not a type: {typ.__class__.__name__}") - - -def c_verenum(idprefix: str, ver: str) -> str: - return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" - - -def c_vercomment(versions: set[str]) -> str | None: - if "9P2000" in versions: - return None - return "/* " + (", ".join(sorted(versions))) + " */" - - -def c_vercond(idprefix: str, versions: set[str]) -> str: - if len(versions) == 1: - return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})" - return ( - "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )" - ) - - -def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]: - return list(typ for typ in typs if isinstance(typ, Struct)) - - -def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: - return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None) - - -def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]: - return list( - typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None - ) - - -def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]: - return list(typ for typ in typs if isinstance(typ, Bitfield)) - - -def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#ifndef _LIB9P_9P_H_ -# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead -#endif - -#include <stdint.h> /* for uint{{n}}_t types */ -""" - - ret += f""" -/* versions *******************************************************************/ - -enum {idprefix}version {{ -""" - fullversions = ["unknown = 0", *sorted(versions)] - verwidth = max(len(v) for v in fullversions) - for ver in fullversions: - ret += f"\t{c_verenum(idprefix, ver)}," - ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' - ret += f"\t{c_verenum(idprefix, 'NUM')},\n" - ret += "};\n" - ret += "\n" - ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n" - - ret += """ -/* non-message types **********************************************************/ -""" - for bf in just_bitfields(typs): - ret += "\n" - ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n" - names = [ - *reversed([bf.bits[n] or f"_UNUSED_{n}" for n in range(0, len(bf.bits))]), - "", - *[k for k in bf.names if k not in bf.bits], - ] - namewidth = max(len(name) for name in names) - - ret += "\n" - for name in names: - if name == "": - ret += "\n" - continue - if name.startswith("_"): - cname = f"_{idprefix.upper()}{bf.name.upper()}_{name[1:]}" - else: - cname = f"{idprefix.upper()}{bf.name.upper()}_{name}" - if name in bf.names: - val = bf.names[name].val - else: - assert name.startswith("_UNUSED_") - val = f"1<<{name[len('_UNUSED_'):]}" - ret += f"#define {cname}{' '*(namewidth-len(name))} (({c_typename(idprefix, bf)})({val}))" - if (name in bf.names) and (comment := c_vercomment(bf.names[name].ver)): - ret += " " + comment - ret += "\n" - - for struct in just_structs_nonmsg(typs): - all_the_same = len(struct.members) == 0 or all( - m.ver == struct.members[0].ver for m in struct.members - ) - typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) - if not all_the_same: - namewidth = max(len(m.name) for m in struct.members) - - ret += "\n" - ret += c_typename(idprefix, struct) + " {\n" - for member in struct.members: - if member.valexpr: - continue - ctype = c_typename(idprefix, member.typ) - if (struct.name in ["d", "s"]) and member.cnt: # SPECIAL - ctype = "char" - ret += f"\t{ctype.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" - if (not all_the_same) and (comment := c_vercomment(member.ver)): - ret += (" " * (namewidth - len(member.name))) + " " + comment - ret += "\n" - ret += "};\n" - - ret += """ -/* messages *******************************************************************/ - -""" - ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" - namewidth = max(len(msg.name) for msg in just_structs_msg(typs)) - for msg in just_structs_msg(typs): - ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," - if comment := c_vercomment(msg.msgver): - ret += " " + comment - ret += "\n" - ret += "};\n" - ret += "\n" - ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n" - - for msg in just_structs_msg(typs): - ret += "\n" - if comment := c_vercomment(msg.msgver): - ret += comment + "\n" - ret += c_typename(idprefix, msg) + " {" - if not msg.members: - ret += "};\n" - continue - ret += "\n" - - all_the_same = len(msg.members) == 0 or all( - m.ver == msg.members[0].ver for m in msg.members - ) - typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) - if not all_the_same: - namewidth = max(len(m.name) for m in msg.members) - - for member in msg.members: - if member.valexpr: - continue - ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" - if (not all_the_same) and (comment := c_vercomment(member.ver)): - ret += (" " * (namewidth - len(member.name))) + " " + comment - ret += "\n" - ret += "};\n" - - return ret - - -def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str: - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#include <assert.h> -#include <stdbool.h> -#include <stddef.h> /* for size_t */ -#include <inttypes.h> /* for PRI* macros */ -#include <string.h> /* for memset() */ - -#include <lib9p/9p.h> - -#include "internal.h" -""" - - def used(arg: str) -> str: - return arg - - def unused(arg: str) -> str: - return f"UNUSED({arg})" - - # strings ################################################################## - ret += f""" -/* strings ********************************************************************/ - -static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{ -""" - for ver in ["unknown", *sorted(versions)]: - ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n' - ret += "};\n" - ret += f""" -const char *{idprefix}version_str(enum {idprefix}version ver) {{ - assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')}); - return version_strs[ver]; -}} - -static const char *msg_type_strs[0x100] = {{ -""" - id2name: dict[int, str] = {} - for msg in just_structs_msg(typs): - assert msg.msgid - id2name[msg.msgid] = msg.name - for n in range(0, 0x100): - ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n))) - ret += "};\n" - ret += f""" -const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{ - assert(0 <= typ && typ <= 0xFF); - return msg_type_strs[typ]; -}} -""" - - # validate_* ############################################################### - ret += """ -/* validate_* *****************************************************************/ - -static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { - if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) - /* If needed-net-size overflowed uint32_t, then - * there's no way that actual-net-size will live up to - * that. */ - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - if (ctx->net_offset > ctx->net_size) - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - return false; -} - -static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { - if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) - /* If needed-host-size overflowed size_t, then there's - * no way that actual-net-size will live up to - * that. */ - return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); - return false; -} - -static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx, - size_t cnt, - _validate_fn_t item_fn, size_t item_host_size) { - for (size_t i = 0; i < cnt; i++) - if (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) - return true; - return false; -} - -#define validate_1(ctx) _validate_size_net(ctx, 1) -#define validate_2(ctx) _validate_size_net(ctx, 2) -#define validate_4(ctx) _validate_size_net(ctx, 4) -#define validate_8(ctx) _validate_size_net(ctx, 8) -""" - for typ in typs: - inline = ( - " FLATTEN" - if (isinstance(typ, Struct) and typ.msgid is not None) - else " ALWAYS_INLINE" - ) - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used - ret += "\n" - ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{" - - if typ.name == "d": # SPECIAL - # Optimize... maybe the compiler could figure out to do - # this, but let's make it obvious. - ret += "\n" - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_4(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" - ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n" - ret += "}\n" - continue - if typ.name == "s": # SPECIAL - # Add an extra nul-byte on the host, and validate UTF-8 - # (also, similar optimization to "d"). - ret += "\n" - ret += "\tuint32_t base_offset = ctx->net_offset;\n" - ret += "\tif (validate_2(ctx))\n" - ret += "\t\treturn true;\n" - ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" - ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n" - ret += "\t\treturn true;\n" - ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" - ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' - ret += "\treturn false;\n" - ret += "}\n" - continue - - match typ: - case Bitfield(): - ret += "\n" - all_the_same = all( - val.ver == [*typ.names.values()][0].ver - for val in typ.names.values() - ) - if ( - all_the_same - and (len(typ.bits) == typ.static_size * 8) - and all(typ.bitname_is_valid(bitname) for bitname in typ.bits) - ): - ret += f"\treturn validate_{typ.static_size}(ctx));\n" - else: - ret += f"\t if (validate_{typ.static_size}(ctx))\n" - ret += "\t\treturn true;\n" - if all_the_same: - ret += ( - f"\tstatic const {c_typename(idprefix, typ)} mask = 0b" - + "".join( - "1" if typ.bitname_is_valid(bitname) else "0" - for bitname in reversed(typ.bits) - ) - + ";\n" - ) - else: - ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_verenum(idprefix, 'NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += ( - f"\t\t[{c_verenum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - "1" if typ.bitname_is_valid(bitname, ver) else "0" - for bitname in reversed(typ.bits) - ) - + ",\n" - ) - ret += "\t};\n" - ret += f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n" - ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += f"\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n' - ret += "\t\t val & ~mask);\n" - ret += "\treturn false;\n" - case Struct(): - if len(typ.members) == 0: - ret += "\n\treturn false;\n" - ret += "}\n" - continue - - for member in typ.members: - if member.max or member.valexpr: - ret += f"\n\t{c_typename(idprefix, member.typ)} {member.name};" - mark_offset: set[str] = set() - for member in typ.members: - for tok in member.valexpr: - if ( - isinstance(tok, ExprVal) - and tok.name.startswith("&") - and tok.name[1:] not in mark_offset - ): - ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" - mark_offset.add(tok.name[1:]) - - prefix0 = "\treturn " - prefix1 = "\t || " - prefix2 = "\t " - - struct_versions = typ.members[0].ver - - prefix = prefix0 - prev_size: int | None = None - for member in typ.members: - ret += f"\n{prefix}" - if member.ver != struct_versions: - ret += "( " + c_vercond(idprefix, member.ver) + " && " - if member.cnt is not None: - assert prev_size - ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" - else: - if member.max or member.valexpr: - ret += "(" - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.name}(ctx)" - if member.name in mark_offset: - ret += "; })" - if member.max or member.valexpr: - bytes = member.static_size - assert bytes - bits = bytes * 8 - ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))" - if member.max: - ret += f"\n{prefix1}" - ret += f'({member.name} > UINT{bits}_C({member.max}) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu{bits}" > %"PRIu{bits}")", {member.name}, UINT{bits}_C({member.max})))' - if member.ver != struct_versions: - ret += " )" - prefix = prefix1 - prev_size = member.static_size - - for member in typ.members: - if member.valexpr: - ret += f"\n{prefix}" - ret += f"({{ uint32_t correct =" - for tok in member.valexpr: - match tok: - case ExprOp(): - ret += f" {tok.op}" - case ExprVal(name="end"): - ret += " ctx->net_offset" - case ExprVal(): - ret += f" _{tok.name[1:]}_offset" - ret += f"; (((uint32_t){member.name}) != correct) &&" - ret += f'\n{prefix2}lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, correct); }})' - - ret += ";\n" - ret += "}\n" - - # unmarshal_* ############################################################## - ret += """ -/* unmarshal_* ****************************************************************/ - -static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { - *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 1; -} - -static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { - *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 2; -} - -static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { - *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 4; -} - -static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { - *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 8; -} -""" - for typ in typs: - inline = ( - " FLATTEN" - if (isinstance(typ, Struct) and typ.msgid is not None) - else " ALWAYS_INLINE" - ) - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used - ret += "\n" - ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n" - match typ: - case Bitfield(): - ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n" - case Struct(): - ret += "\tmemset(out, 0, sizeof(*out));\n" - - if typ.members: - struct_versions = typ.members[0].ver - for member in typ.members: - if member.valexpr: - ret += f"\tctx->net_offset += {member.static_size};\n" - continue - ret += "\t" - prefix = "\t" - if member.ver != struct_versions: - ret += "if ( " + c_vercond(idprefix, member.ver) + " ) " - prefix = "\t\t" - if member.cnt: - if member.ver != struct_versions: - ret += f"{{\n{prefix}" - ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" - ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" - if typ.name in ["d", "s"]: # SPECIAL - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" - else: - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" - if member.ver != struct_versions: - ret += "\t}\n" - else: - ret += ( - f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" - ) - if typ.name == "s": # SPECIAL - ret += "\tctx->extra++;\n" - ret += "\tout->utf8[out->len] = '\\0';\n" - ret += "}\n" - - # marshal_* ################################################################ - ret += """ -/* marshal_* ******************************************************************/ - -static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) { - lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", - (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", - ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), - ctx->ctx->max_msg_size); - return true; -} - -static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { - if (ctx->net_offset + 1 > ctx->ctx->max_msg_size) - return _marshal_too_large(ctx); - ctx->net_bytes[ctx->net_offset] = *val; - ctx->net_offset += 1; - return false; -} - -static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { - if (ctx->net_offset + 2 > ctx->ctx->max_msg_size) - return _marshal_too_large(ctx); - encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 2; - return false; -} - -static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { - if (ctx->net_offset + 4 > ctx->ctx->max_msg_size) - return true; - encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 4; - return false; -} - -static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { - if (ctx->net_offset + 8 > ctx->ctx->max_msg_size) - return true; - encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]); - ctx->net_offset += 8; - return false; -} -""" - for typ in typs: - inline = ( - " FLATTEN" - if (isinstance(typ, Struct) and typ.msgid is not None) - else " ALWAYS_INLINE" - ) - argfn = unused if (isinstance(typ, Struct) and not typ.members) else used - ret += "\n" - ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{" - match typ: - case Bitfield(): - ret += "\n" - ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n" - case Struct(): - if len(typ.members) == 0: - ret += "\n\treturn false;\n" - ret += "}\n" - continue - - mark_offset = set() - for member in typ.members: - if member.valexpr: - if member.name not in mark_offset: - ret += f"\n\tuint32_t _{member.name}_offset;" - mark_offset.add(member.name) - for tok in member.valexpr: - if ( - isinstance(tok, ExprVal) - and tok.name.startswith("&") - and tok.name[1:] not in mark_offset - ): - ret += f"\n\tuint32_t _{tok.name[1:]}_offset;" - mark_offset.add(tok.name[1:]) - - prefix0 = "\treturn " - prefix1 = "\t || " - prefix2 = "\t " - - struct_versions = typ.members[0].ver - prefix = prefix0 - for member in typ.members: - ret += f"\n{prefix}" - if member.ver != struct_versions: - ret += "( " + c_vercond(idprefix, member.ver) + " && " - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - if member.cnt: - ret += "({" - ret += f"\n{prefix2}\tbool err = false;" - ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" - if typ.name in ["d", "s"]: # SPECIAL - ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);" - else: - ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);" - ret += f"\n{prefix2}\terr;" - ret += f"\n{prefix2}}})" - elif member.valexpr: - assert member.static_size - ret += ( - f"({{ ctx->net_offset += {member.static_size}; false; }})" - ) - else: - ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" - if member.name in mark_offset: - ret += "; })" - if member.ver != struct_versions: - ret += " )" - prefix = prefix1 - - for member in typ.members: - if member.valexpr: - assert member.static_size - ret += f"\n{prefix}" - ret += f"({{ encode_u{member.static_size*8}le(" - for tok in member.valexpr: - match tok: - case ExprOp(): - ret += f" {tok.op}" - case ExprVal(name="end"): - ret += " ctx->net_offset" - case ExprVal(): - ret += f" _{tok.name[1:]}_offset" - ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})" - - ret += ";\n" - ret += "}\n" - - # vtables ################################################################## - ret += f""" -/* vtables ********************************************************************/ - -#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\ - .basesize = sizeof(struct {idprefix}msg_##typ), \\ - .validate = validate_##typ, \\ - .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\ - .marshal = (_marshal_fn_t)marshal_##typ, \\ - }} - -struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{ -""" - - ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n" - for msg in just_structs_msg(typs): - if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL - ret += f"\t\t_MSG({msg.name}),\n" - ret += "\t}},\n" - - for ver in sorted(versions): - ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n" - for msg in just_structs_msg(typs): - if ver not in msg.msgver: - continue - ret += f"\t\t_MSG({msg.name}),\n" - ret += "\t}},\n" - ret += "};\n" - - ############################################################################ - return ret - - -################################################################################ - - -class Parser: - cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {} - - def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]: - filename = os.path.normpath(filename) - if filename not in self.cache: - - def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]: - return self.parse_file(os.path.join(filename, "..", other_filename)) - - self.cache[filename] = parse_file(filename, get_include) - return self.cache[filename] - - def all(self) -> tuple[set[str], list[Bitfield | Struct]]: - ret_versions: set[str] = set() - ret_typs: dict[str, Bitfield | Struct] = {} - for version, typs in self.cache.values(): - if version in ret_versions: - raise ValueError(f"duplicate protocol version {repr(version)}") - ret_versions.add(version) - for typ in typs: - if typ.name in ret_typs: - if typ != ret_typs[typ.name]: - raise ValueError(f"duplicate type name {repr(typ.name)}") - else: - ret_typs[typ.name] = typ - return ret_versions, list(ret_typs.values()) - - -if __name__ == "__main__": - import sys - - if len(sys.argv) < 2: - raise ValueError("requires at least 1 .txt filename") - parser = Parser() - for txtname in sys.argv[1:]: - parser.parse_file(txtname) - versions, typs = parser.all() - outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) - with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: - fh.write(gen_h("lib9p_", versions, typs)) - with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: - fh.write(gen_c("lib9p_", versions, typs)) |