summaryrefslogtreecommitdiff
path: root/lib9p/9p.gen
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/9p.gen')
-rwxr-xr-xlib9p/9p.gen1059
1 files changed, 0 insertions, 1059 deletions
diff --git a/lib9p/9p.gen b/lib9p/9p.gen
deleted file mode 100755
index 816ec0a..0000000
--- a/lib9p/9p.gen
+++ /dev/null
@@ -1,1059 +0,0 @@
-#!/usr/bin/env python
-# lib9p/9p.gen - Generate C marshalers/unmarshalers for .txt files
-# defining 9P protocol variants.
-#
-# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
-# SPDX-Licence-Identifier: AGPL-3.0-or-later
-
-import enum
-import os.path
-import re
-from typing import Callable, Literal, Sequence
-
-# This strives to be "general-purpose" in that it just acts on the
-# *.txt inputs; but (unfortunately?) there are a few special-cases in
-# this script, marked with "SPECIAL".
-
-# Parse *.txt ##################################################################
-
-
-class Atom(enum.Enum):
- u8 = 1
- u16 = 2
- u32 = 4
- u64 = 8
-
- @property
- def name(self) -> str:
- return str(self.value)
-
- @property
- def static_size(self) -> int:
- return self.value
-
-
-class BitfieldVal:
- name: str
- val: str
- ver: set[str]
-
- def __init__(self) -> None:
- self.ver = set()
-
-
-class Bitfield:
- name: str
- bits: list[str]
- names: dict[str, BitfieldVal]
-
- @property
- def static_size(self) -> int:
- return int((len(self.bits) + 7) / 8)
-
- def bitname_is_valid(self, bitname: str, ver: str | None = None) -> bool:
- assert bitname in self.bits
- if not bitname:
- return False
- if bitname.startswith("_"):
- return False
- if ver and (ver not in self.names[bitname].ver):
- return False
- return True
-
-
-# `msgid/structname = "member1 member2..."`
-# `structname = "member1 member2..."`
-# `structname += "member1 member2..."`
-class Struct:
- msgid: int | None = None
- msgver: set[str]
- name: str
- members: list["Member"]
-
- def __init__(self) -> None:
- self.msgver = set()
-
- @property
- def static_size(self) -> int | None:
- size = 0
- for member in self.members:
- msize = member.static_size
- if msize is None:
- return None
- size += msize
- return size
-
-
-class ExprVal:
- name: str
-
- def __init__(self, name: str) -> None:
- self.name = name
-
-
-class ExprOp:
- op: Literal["-", "+"]
-
- def __init__(self, op: Literal["-", "+"]) -> None:
- self.op = op
-
-
-# `cnt*(name[typ])`
-# the `cnt*(...)` wrapper is optional
-class Member:
- cnt: str | None = None
- name: str
- typ: Atom | Bitfield | Struct
- max: int | None = None
- valexpr: list[ExprVal | ExprOp] = []
- ver: set[str]
-
- @property
- def static_size(self) -> int | None:
- if self.cnt:
- return None
- return self.typ.static_size
-
-
-def parse_valexpr(valexpr: str) -> list[ExprVal | ExprOp]:
- ret: list[ExprVal | ExprOp] = []
- for tok in re.split("([-+])", valexpr):
- match tok:
- case "-":
- ret += [ExprOp(tok)]
- case "+":
- ret += [ExprOp(tok)]
- case _:
- ret += [ExprVal(tok)]
- return ret
-
-
-re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)"
-re_memberspec = f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>[^,]*)(?:,max=(?P<max>[0-9]+)|,val=(?P<val>[-+&a-zA-Z0-9_]+))*\\]\\)?"
-
-
-def parse_members(
- ver: str,
- env: dict[str, Atom | Bitfield | Struct],
- existing: list[Member],
- specs: str,
-) -> list[Member]:
- ret = existing
- for spec in specs.split():
- m = re.fullmatch(re_memberspec, spec)
- if not m:
- raise SyntaxError(f"invalid member spec {repr(spec)}")
-
- member = Member()
- member.ver = {ver}
-
- member.name = m.group("name")
- if any(x.name == member.name for x in ret):
- raise ValueError(f"duplicate member name {repr(member.name)}")
-
- if m.group("typ") not in env:
- raise NameError(f"Unknown type {repr(m.group(2))}")
- member.typ = env[m.group("typ")]
-
- if cnt := m.group("cnt"):
- if len(ret) == 0 or ret[-1].name != cnt:
- raise ValueError(f"list count must be previous item: {repr(cnt)}")
- if not isinstance(ret[-1].typ, Atom):
- raise ValueError(f"list count must be an integer type: {repr(cnt)}")
- member.cnt = cnt
-
- if maxstr := m.group("max"):
- if (not isinstance(member.typ, Atom)) or member.cnt:
- raise ValueError("',max=' may only be specified on a non-repeated atom")
- member.max = int(maxstr)
-
- if valstr := m.group("val"):
- if (not isinstance(member.typ, Atom)) or member.cnt:
- raise ValueError("',val=' may only be specified on a non-repeated atom")
- member.valexpr = parse_valexpr(valstr)
-
- ret += [member]
- return ret
-
-
-re_version = r'version\s+"(?P<version>[^"]+)"'
-re_import = r"from\s+(?P<file>\S+)\s+import\s+(?P<syms>\S+(?:\s*,\s*\S+)*)\s*"
-re_structspec = (
- r'(?:(?P<msgid>[0-9]+)/)?(?P<name>\S+)\s*(?P<op>\+?=)\s*"(?P<members>[^"]*)"'
-)
-re_structspec_cont = r'\s+"(?P<members>[^"]*)"'
-re_bitfieldspec = r"bitfield\s+(?P<name>\S+)\s+(?P<size>[0-9]+)"
-re_bitfieldspec_bit = r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<bit>[0-9]+)/(?P<name>\S+)"
-re_bitfieldspec_alias = (
- r"(?:\s+|(?P<bitfield>\S+)\s*\+=\s*)(?P<name>\S+)\s*=\s*(?P<val>.*)"
-)
-
-
-def parse_file(
- filename: str, get_include: Callable[[str], tuple[str, list[Bitfield | Struct]]]
-) -> tuple[str, list[Bitfield | Struct]]:
- version: str | None = None
- env: dict[str, Atom | Bitfield | Struct] = {
- "1": Atom.u8,
- "2": Atom.u16,
- "4": Atom.u32,
- "8": Atom.u64,
- }
- with open(filename, "r") as fh:
- prev: Struct | Bitfield | None = None
- for line in fh:
- line = line.split("#", 1)[0].rstrip()
- if not line:
- continue
- if m := re.fullmatch(re_version, line):
- if version:
- raise SyntaxError("must have exactly 1 version line")
- version = m.group("version")
- elif m := re.fullmatch(re_import, line):
- if not version:
- raise SyntaxError("must have exactly 1 version line")
- other_version, other_typs = get_include(m.group("file"))
- for symname in m.group("syms").split(sep=","):
- symname = symname.strip()
- for typ in other_typs:
- if typ.name == symname or symname == "*":
- match typ:
- case Bitfield():
- for val in typ.names.values():
- if other_version in val.ver:
- val.ver.add(version)
- case Struct():
- if typ.msgid:
- typ.msgver.add(version)
- for member in typ.members:
- if other_version in member.ver:
- member.ver.add(version)
- env[typ.name] = typ
- elif m := re.fullmatch(re_structspec, line):
- if not version:
- raise SyntaxError("must have exactly 1 version line")
- if m.group("op") == "+=" and m.group("msgid"):
- raise SyntaxError("cannot += to a message that is not yet defined")
- match m.group("op"):
- case "=":
- struct = Struct()
- if m.group("msgid"):
- struct.msgid = int(m.group("msgid"))
- struct.msgver.add(version)
- struct.name = m.group("name")
- struct.members = parse_members(
- version, env, [], m.group("members")
- )
- env[struct.name] = struct
- prev = struct
- case "+=":
- if m.group("name") not in env:
- raise NameError(f"Unknown type {repr(m.group('name'))}")
- _struct = env[m.group("name")]
- if not isinstance(_struct, Struct):
- raise NameError(
- f"Type {repr(_struct.name)} is not a struct"
- )
- struct = _struct
- struct.members = parse_members(
- version, env, struct.members, m.group("members")
- )
- prev = struct
- elif m := re.fullmatch(re_structspec_cont, line):
- if not isinstance(prev, Struct):
- raise SyntaxError(
- "struct-continuation line must come after a struct line"
- )
- assert version
- prev.members = parse_members(
- version, env, prev.members, m.group("members")
- )
- elif m := re.fullmatch(re_bitfieldspec, line):
- if not version:
- raise SyntaxError("must have exactly 1 version line")
- bf = Bitfield()
- bf.name = m.group("name")
- bf.bits = int(m.group("size")) * [""]
- bf.names = {}
- if len(bf.bits) not in [8, 16, 32, 64]:
- raise ValueError(f"Bitfield {repr(bf.name)} has an unusual size")
- env[bf.name] = bf
- prev = bf
- elif m := re.fullmatch(re_bitfieldspec_bit, line):
- if m.group("bitfield"):
- if m.group("bitfield") not in env:
- raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}")
- _bf = env[m.group("bitfield")]
- if not isinstance(_bf, Bitfield):
- raise NameError(f"Type {repr(_bf.name)} is not a bitfield")
- bf = _bf
- prev = bf
- else:
- if not isinstance(prev, Bitfield):
- raise SyntaxError(
- "bitfield-continuation line must come after a bitfield line"
- )
- bf = prev
- bit = int(m.group("bit"))
- name = m.group("name")
- if bit < 0 or bit >= len(bf.bits):
- raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
- if bf.bits[bit]:
- raise ValueError(f"{bf.name}: bit {bit} already assigned")
- if name in bf.names:
- raise ValueError(f"{bf.name}: name {name} already assigned")
-
- bf.bits[bit] = name
-
- assert version
- val = BitfieldVal()
- val.name = name
- val.val = f"1<<{bit}"
- val.ver.add(version)
- bf.names[name] = val
- elif m := re.fullmatch(re_bitfieldspec_alias, line):
- if m.group("bitfield"):
- if m.group("bitfield") not in env:
- raise NameError(f"Unknown bitfield {repr(m.group('bitfield'))}")
- _bf = env[m.group("bitfield")]
- if not isinstance(_bf, Bitfield):
- raise NameError(f"Type {repr(_bf.name)} is not a bitfield")
- bf = _bf
- prev = bf
- else:
- if not isinstance(prev, Bitfield):
- raise SyntaxError(
- "bitfield-continuation line must come after a bitfield line"
- )
- bf = prev
- name = m.group("name")
- valstr = m.group("val")
- if name in bf.names:
- raise ValueError(f"{bf.name}: name {name} already assigned")
-
- assert version
- val = BitfieldVal()
- val.name = name
- val.val = valstr
- val.ver.add(version)
- bf.names[name] = val
- else:
- raise SyntaxError(f"invalid line {repr(line)}")
- if not version:
- raise SyntaxError("must have exactly 1 version line")
-
- typs = [x for x in env.values() if not isinstance(x, Atom)]
-
- for typ in just_structs_all(typs):
- valid_vals = ["end", *["&" + m.name for m in typ.members]]
- for member in typ.members:
- for tok in member.valexpr:
- if isinstance(tok, ExprVal) and tok.name not in valid_vals:
- raise ValueError(
- f"{typ.name}.{member.name}: invalid val: {tok.name}"
- )
-
- return version, typs
-
-
-# Generate C ###################################################################
-
-
-def c_typename(idprefix: str, typ: Atom | Bitfield | Struct) -> str:
- match typ:
- case Atom():
- return f"uint{typ.value*8}_t"
- case Bitfield():
- return f"{idprefix}{typ.name}_t"
- case Struct():
- if typ.msgid is not None:
- return f"struct {idprefix}msg_{typ.name}"
- return f"struct {idprefix}{typ.name}"
- case _:
- raise ValueError(f"not a type: {typ.__class__.__name__}")
-
-
-def c_verenum(idprefix: str, ver: str) -> str:
- return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
-
-
-def c_vercomment(versions: set[str]) -> str | None:
- if "9P2000" in versions:
- return None
- return "/* " + (", ".join(sorted(versions))) + " */"
-
-
-def c_vercond(idprefix: str, versions: set[str]) -> str:
- if len(versions) == 1:
- return f"(ctx->ctx->version=={c_verenum(idprefix, next(v for v in versions))})"
- return (
- "( " + (" || ".join(c_vercond(idprefix, {v}) for v in sorted(versions))) + " )"
- )
-
-
-def just_structs_all(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
- return list(typ for typ in typs if isinstance(typ, Struct))
-
-
-def just_structs_nonmsg(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
- return list(typ for typ in typs if isinstance(typ, Struct) and typ.msgid is None)
-
-
-def just_structs_msg(typs: list[Bitfield | Struct]) -> Sequence[Struct]:
- return list(
- typ for typ in typs if isinstance(typ, Struct) and typ.msgid is not None
- )
-
-
-def just_bitfields(typs: list[Bitfield | Struct]) -> Sequence[Bitfield]:
- return list(typ for typ in typs if isinstance(typ, Bitfield))
-
-
-def gen_h(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str:
- ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
-
-#ifndef _LIB9P_9P_H_
-# error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
-#endif
-
-#include <stdint.h> /* for uint{{n}}_t types */
-"""
-
- ret += f"""
-/* versions *******************************************************************/
-
-enum {idprefix}version {{
-"""
- fullversions = ["unknown = 0", *sorted(versions)]
- verwidth = max(len(v) for v in fullversions)
- for ver in fullversions:
- ret += f"\t{c_verenum(idprefix, ver)},"
- ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
- ret += f"\t{c_verenum(idprefix, 'NUM')},\n"
- ret += "};\n"
- ret += "\n"
- ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
-
- ret += """
-/* non-message types **********************************************************/
-"""
- for bf in just_bitfields(typs):
- ret += "\n"
- ret += f"typedef uint{bf.static_size*8}_t {c_typename(idprefix, bf)};\n"
- names = [
- *reversed([bf.bits[n] or f"_UNUSED_{n}" for n in range(0, len(bf.bits))]),
- "",
- *[k for k in bf.names if k not in bf.bits],
- ]
- namewidth = max(len(name) for name in names)
-
- ret += "\n"
- for name in names:
- if name == "":
- ret += "\n"
- continue
- if name.startswith("_"):
- cname = f"_{idprefix.upper()}{bf.name.upper()}_{name[1:]}"
- else:
- cname = f"{idprefix.upper()}{bf.name.upper()}_{name}"
- if name in bf.names:
- val = bf.names[name].val
- else:
- assert name.startswith("_UNUSED_")
- val = f"1<<{name[len('_UNUSED_'):]}"
- ret += f"#define {cname}{' '*(namewidth-len(name))} (({c_typename(idprefix, bf)})({val}))"
- if (name in bf.names) and (comment := c_vercomment(bf.names[name].ver)):
- ret += " " + comment
- ret += "\n"
-
- for struct in just_structs_nonmsg(typs):
- all_the_same = len(struct.members) == 0 or all(
- m.ver == struct.members[0].ver for m in struct.members
- )
- typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members)
- if not all_the_same:
- namewidth = max(len(m.name) for m in struct.members)
-
- ret += "\n"
- ret += c_typename(idprefix, struct) + " {\n"
- for member in struct.members:
- if member.valexpr:
- continue
- ctype = c_typename(idprefix, member.typ)
- if (struct.name in ["d", "s"]) and member.cnt: # SPECIAL
- ctype = "char"
- ret += f"\t{ctype.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};"
- if (not all_the_same) and (comment := c_vercomment(member.ver)):
- ret += (" " * (namewidth - len(member.name))) + " " + comment
- ret += "\n"
- ret += "};\n"
-
- ret += """
-/* messages *******************************************************************/
-
-"""
- ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
- namewidth = max(len(msg.name) for msg in just_structs_msg(typs))
- for msg in just_structs_msg(typs):
- ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},"
- if comment := c_vercomment(msg.msgver):
- ret += " " + comment
- ret += "\n"
- ret += "};\n"
- ret += "\n"
- ret += f"const char *{idprefix}msg_type_str(enum {idprefix}msg_type);\n"
-
- for msg in just_structs_msg(typs):
- ret += "\n"
- if comment := c_vercomment(msg.msgver):
- ret += comment + "\n"
- ret += c_typename(idprefix, msg) + " {"
- if not msg.members:
- ret += "};\n"
- continue
- ret += "\n"
-
- all_the_same = len(msg.members) == 0 or all(
- m.ver == msg.members[0].ver for m in msg.members
- )
- typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members)
- if not all_the_same:
- namewidth = max(len(m.name) for m in msg.members)
-
- for member in msg.members:
- if member.valexpr:
- continue
- ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};"
- if (not all_the_same) and (comment := c_vercomment(member.ver)):
- ret += (" " * (namewidth - len(member.name))) + " " + comment
- ret += "\n"
- ret += "};\n"
-
- return ret
-
-
-def gen_c(idprefix: str, versions: set[str], typs: list[Bitfield | Struct]) -> str:
- ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */
-
-#include <assert.h>
-#include <stdbool.h>
-#include <stddef.h> /* for size_t */
-#include <inttypes.h> /* for PRI* macros */
-#include <string.h> /* for memset() */
-
-#include <lib9p/9p.h>
-
-#include "internal.h"
-"""
-
- def used(arg: str) -> str:
- return arg
-
- def unused(arg: str) -> str:
- return f"UNUSED({arg})"
-
- # strings ##################################################################
- ret += f"""
-/* strings ********************************************************************/
-
-static const char *version_strs[{c_verenum(idprefix, 'NUM')}] = {{
-"""
- for ver in ["unknown", *sorted(versions)]:
- ret += f'\t[{c_verenum(idprefix, ver)}] = "{ver}",\n'
- ret += "};\n"
- ret += f"""
-const char *{idprefix}version_str(enum {idprefix}version ver) {{
- assert(0 <= ver && ver < {c_verenum(idprefix, 'NUM')});
- return version_strs[ver];
-}}
-
-static const char *msg_type_strs[0x100] = {{
-"""
- id2name: dict[int, str] = {}
- for msg in just_structs_msg(typs):
- assert msg.msgid
- id2name[msg.msgid] = msg.name
- for n in range(0, 0x100):
- ret += '\t[0x{:02X}] = "{}",\n'.format(n, id2name.get(n, "0x{:02X}".format(n)))
- ret += "};\n"
- ret += f"""
-const char *{idprefix}msg_type_str(enum {idprefix}msg_type typ) {{
- assert(0 <= typ && typ <= 0xFF);
- return msg_type_strs[typ];
-}}
-"""
-
- # validate_* ###############################################################
- ret += """
-/* validate_* *****************************************************************/
-
-static ALWAYS_INLINE bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
- if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
- /* If needed-net-size overflowed uint32_t, then
- * there's no way that actual-net-size will live up to
- * that. */
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- if (ctx->net_offset > ctx->net_size)
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- return false;
-}
-
-static ALWAYS_INLINE bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
- if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
- /* If needed-host-size overflowed size_t, then there's
- * no way that actual-net-size will live up to
- * that. */
- return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
- return false;
-}
-
-static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
- size_t cnt,
- _validate_fn_t item_fn, size_t item_host_size) {
- for (size_t i = 0; i < cnt; i++)
- if (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
- return true;
- return false;
-}
-
-#define validate_1(ctx) _validate_size_net(ctx, 1)
-#define validate_2(ctx) _validate_size_net(ctx, 2)
-#define validate_4(ctx) _validate_size_net(ctx, 4)
-#define validate_8(ctx) _validate_size_net(ctx, 8)
-"""
- for typ in typs:
- inline = (
- " FLATTEN"
- if (isinstance(typ, Struct) and typ.msgid is not None)
- else " ALWAYS_INLINE"
- )
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
- ret += "\n"
- ret += f"static{inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{"
-
- if typ.name == "d": # SPECIAL
- # Optimize... maybe the compiler could figure out to do
- # this, but let's make it obvious.
- ret += "\n"
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_4(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
- ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
- ret += "}\n"
- continue
- if typ.name == "s": # SPECIAL
- # Add an extra nul-byte on the host, and validate UTF-8
- # (also, similar optimization to "d").
- ret += "\n"
- ret += "\tuint32_t base_offset = ctx->net_offset;\n"
- ret += "\tif (validate_2(ctx))\n"
- ret += "\t\treturn true;\n"
- ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
- ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
- ret += "\t\treturn true;\n"
- ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
- ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
- ret += "\treturn false;\n"
- ret += "}\n"
- continue
-
- match typ:
- case Bitfield():
- ret += "\n"
- all_the_same = all(
- val.ver == [*typ.names.values()][0].ver
- for val in typ.names.values()
- )
- if (
- all_the_same
- and (len(typ.bits) == typ.static_size * 8)
- and all(typ.bitname_is_valid(bitname) for bitname in typ.bits)
- ):
- ret += f"\treturn validate_{typ.static_size}(ctx));\n"
- else:
- ret += f"\t if (validate_{typ.static_size}(ctx))\n"
- ret += "\t\treturn true;\n"
- if all_the_same:
- ret += (
- f"\tstatic const {c_typename(idprefix, typ)} mask = 0b"
- + "".join(
- "1" if typ.bitname_is_valid(bitname) else "0"
- for bitname in reversed(typ.bits)
- )
- + ";\n"
- )
- else:
- ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_verenum(idprefix, 'NUM')}] = {{\n"
- verwidth = max(len(ver) for ver in versions)
- for ver in sorted(versions):
- ret += (
- f"\t\t[{c_verenum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b"
- + "".join(
- "1" if typ.bitname_is_valid(bitname, ver) else "0"
- for bitname in reversed(typ.bits)
- )
- + ",\n"
- )
- ret += "\t};\n"
- ret += f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n"
- ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
- ret += f"\tif (val & ~mask)\n"
- ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n'
- ret += "\t\t val & ~mask);\n"
- ret += "\treturn false;\n"
- case Struct():
- if len(typ.members) == 0:
- ret += "\n\treturn false;\n"
- ret += "}\n"
- continue
-
- for member in typ.members:
- if member.max or member.valexpr:
- ret += f"\n\t{c_typename(idprefix, member.typ)} {member.name};"
- mark_offset: set[str] = set()
- for member in typ.members:
- for tok in member.valexpr:
- if (
- isinstance(tok, ExprVal)
- and tok.name.startswith("&")
- and tok.name[1:] not in mark_offset
- ):
- ret += f"\n\tuint32_t _{tok.name[1:]}_offset;"
- mark_offset.add(tok.name[1:])
-
- prefix0 = "\treturn "
- prefix1 = "\t || "
- prefix2 = "\t "
-
- struct_versions = typ.members[0].ver
-
- prefix = prefix0
- prev_size: int | None = None
- for member in typ.members:
- ret += f"\n{prefix}"
- if member.ver != struct_versions:
- ret += "( " + c_vercond(idprefix, member.ver) + " && "
- if member.cnt is not None:
- assert prev_size
- ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
- else:
- if member.max or member.valexpr:
- ret += "("
- if member.name in mark_offset:
- ret += f"({{ _{member.name}_offset = ctx->net_offset; "
- ret += f"validate_{member.typ.name}(ctx)"
- if member.name in mark_offset:
- ret += "; })"
- if member.max or member.valexpr:
- bytes = member.static_size
- assert bytes
- bits = bytes * 8
- ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
- if member.max:
- ret += f"\n{prefix1}"
- ret += f'({member.name} > UINT{bits}_C({member.max}) && lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu{bits}" > %"PRIu{bits}")", {member.name}, UINT{bits}_C({member.max})))'
- if member.ver != struct_versions:
- ret += " )"
- prefix = prefix1
- prev_size = member.static_size
-
- for member in typ.members:
- if member.valexpr:
- ret += f"\n{prefix}"
- ret += f"({{ uint32_t correct ="
- for tok in member.valexpr:
- match tok:
- case ExprOp():
- ret += f" {tok.op}"
- case ExprVal(name="end"):
- ret += " ctx->net_offset"
- case ExprVal():
- ret += f" _{tok.name[1:]}_offset"
- ret += f"; (((uint32_t){member.name}) != correct) &&"
- ret += f'\n{prefix2}lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, correct); }})'
-
- ret += ";\n"
- ret += "}\n"
-
- # unmarshal_* ##############################################################
- ret += """
-/* unmarshal_* ****************************************************************/
-
-static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
- *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 1;
-}
-
-static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
- *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 2;
-}
-
-static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
- *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 4;
-}
-
-static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
- *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 8;
-}
-"""
- for typ in typs:
- inline = (
- " FLATTEN"
- if (isinstance(typ, Struct) and typ.msgid is not None)
- else " ALWAYS_INLINE"
- )
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
- ret += "\n"
- ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n"
- match typ:
- case Bitfield():
- ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n"
- case Struct():
- ret += "\tmemset(out, 0, sizeof(*out));\n"
-
- if typ.members:
- struct_versions = typ.members[0].ver
- for member in typ.members:
- if member.valexpr:
- ret += f"\tctx->net_offset += {member.static_size};\n"
- continue
- ret += "\t"
- prefix = "\t"
- if member.ver != struct_versions:
- ret += "if ( " + c_vercond(idprefix, member.ver) + " ) "
- prefix = "\t\t"
- if member.cnt:
- if member.ver != struct_versions:
- ret += f"{{\n{prefix}"
- ret += f"out->{member.name} = ctx->extra;\n"
- ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
- ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
- if typ.name in ["d", "s"]: # SPECIAL
- ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
- else:
- ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
- if member.ver != struct_versions:
- ret += "\t}\n"
- else:
- ret += (
- f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
- )
- if typ.name == "s": # SPECIAL
- ret += "\tctx->extra++;\n"
- ret += "\tout->utf8[out->len] = '\\0';\n"
- ret += "}\n"
-
- # marshal_* ################################################################
- ret += """
-/* marshal_* ******************************************************************/
-
-static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
- lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
- (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
- ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
- ctx->ctx->max_msg_size);
- return true;
-}
-
-static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
- if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
- return _marshal_too_large(ctx);
- ctx->net_bytes[ctx->net_offset] = *val;
- ctx->net_offset += 1;
- return false;
-}
-
-static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
- if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
- return _marshal_too_large(ctx);
- encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 2;
- return false;
-}
-
-static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
- if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
- return true;
- encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 4;
- return false;
-}
-
-static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
- if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
- return true;
- encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
- ctx->net_offset += 8;
- return false;
-}
-"""
- for typ in typs:
- inline = (
- " FLATTEN"
- if (isinstance(typ, Struct) and typ.msgid is not None)
- else " ALWAYS_INLINE"
- )
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
- ret += "\n"
- ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{"
- match typ:
- case Bitfield():
- ret += "\n"
- ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n"
- case Struct():
- if len(typ.members) == 0:
- ret += "\n\treturn false;\n"
- ret += "}\n"
- continue
-
- mark_offset = set()
- for member in typ.members:
- if member.valexpr:
- if member.name not in mark_offset:
- ret += f"\n\tuint32_t _{member.name}_offset;"
- mark_offset.add(member.name)
- for tok in member.valexpr:
- if (
- isinstance(tok, ExprVal)
- and tok.name.startswith("&")
- and tok.name[1:] not in mark_offset
- ):
- ret += f"\n\tuint32_t _{tok.name[1:]}_offset;"
- mark_offset.add(tok.name[1:])
-
- prefix0 = "\treturn "
- prefix1 = "\t || "
- prefix2 = "\t "
-
- struct_versions = typ.members[0].ver
- prefix = prefix0
- for member in typ.members:
- ret += f"\n{prefix}"
- if member.ver != struct_versions:
- ret += "( " + c_vercond(idprefix, member.ver) + " && "
- if member.name in mark_offset:
- ret += f"({{ _{member.name}_offset = ctx->net_offset; "
- if member.cnt:
- ret += "({"
- ret += f"\n{prefix2}\tbool err = false;"
- ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
- if typ.name in ["d", "s"]: # SPECIAL
- ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);"
- else:
- ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
- ret += f"\n{prefix2}\terr;"
- ret += f"\n{prefix2}}})"
- elif member.valexpr:
- assert member.static_size
- ret += (
- f"({{ ctx->net_offset += {member.static_size}; false; }})"
- )
- else:
- ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
- if member.name in mark_offset:
- ret += "; })"
- if member.ver != struct_versions:
- ret += " )"
- prefix = prefix1
-
- for member in typ.members:
- if member.valexpr:
- assert member.static_size
- ret += f"\n{prefix}"
- ret += f"({{ encode_u{member.static_size*8}le("
- for tok in member.valexpr:
- match tok:
- case ExprOp():
- ret += f" {tok.op}"
- case ExprVal(name="end"):
- ret += " ctx->net_offset"
- case ExprVal():
- ret += f" _{tok.name[1:]}_offset"
- ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})"
-
- ret += ";\n"
- ret += "}\n"
-
- # vtables ##################################################################
- ret += f"""
-/* vtables ********************************************************************/
-
-#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
- .basesize = sizeof(struct {idprefix}msg_##typ), \\
- .validate = validate_##typ, \\
- .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
- .marshal = (_marshal_fn_t)marshal_##typ, \\
- }}
-
-struct _vtable_version _{idprefix}vtables[{c_verenum(idprefix, 'NUM')}] = {{
-"""
-
- ret += f"\t[{c_verenum(idprefix, 'unknown')}] = {{ .msgs = {{\n"
- for msg in just_structs_msg(typs):
- if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL
- ret += f"\t\t_MSG({msg.name}),\n"
- ret += "\t}},\n"
-
- for ver in sorted(versions):
- ret += f"\t[{c_verenum(idprefix, ver)}] = {{ .msgs = {{\n"
- for msg in just_structs_msg(typs):
- if ver not in msg.msgver:
- continue
- ret += f"\t\t_MSG({msg.name}),\n"
- ret += "\t}},\n"
- ret += "};\n"
-
- ############################################################################
- return ret
-
-
-################################################################################
-
-
-class Parser:
- cache: dict[str, tuple[str, list[Bitfield | Struct]]] = {}
-
- def parse_file(self, filename: str) -> tuple[str, list[Bitfield | Struct]]:
- filename = os.path.normpath(filename)
- if filename not in self.cache:
-
- def get_include(other_filename: str) -> tuple[str, list[Bitfield | Struct]]:
- return self.parse_file(os.path.join(filename, "..", other_filename))
-
- self.cache[filename] = parse_file(filename, get_include)
- return self.cache[filename]
-
- def all(self) -> tuple[set[str], list[Bitfield | Struct]]:
- ret_versions: set[str] = set()
- ret_typs: dict[str, Bitfield | Struct] = {}
- for version, typs in self.cache.values():
- if version in ret_versions:
- raise ValueError(f"duplicate protocol version {repr(version)}")
- ret_versions.add(version)
- for typ in typs:
- if typ.name in ret_typs:
- if typ != ret_typs[typ.name]:
- raise ValueError(f"duplicate type name {repr(typ.name)}")
- else:
- ret_typs[typ.name] = typ
- return ret_versions, list(ret_typs.values())
-
-
-if __name__ == "__main__":
- import sys
-
- if len(sys.argv) < 2:
- raise ValueError("requires at least 1 .txt filename")
- parser = Parser()
- for txtname in sys.argv[1:]:
- parser.parse_file(txtname)
- versions, typs = parser.all()
- outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
- with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
- fh.write(gen_h("lib9p_", versions, typs))
- with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
- fh.write(gen_c("lib9p_", versions, typs))