From e5e15c04bc58c34906e6d7cfcbad68d1a5617563 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Fri, 27 Sep 2024 17:25:36 -0600 Subject: wip --- lib9p/types.gen | 611 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100755 lib9p/types.gen (limited to 'lib9p/types.gen') diff --git a/lib9p/types.gen b/lib9p/types.gen new file mode 100755 index 0000000..141b591 --- /dev/null +++ b/lib9p/types.gen @@ -0,0 +1,611 @@ +#!/usr/bin/env python +# lib9p/types.gen - Generate C marshalers/unmarshalers for .txt files +# defining 9P protocol variants. +# +# Copyright (C) 2024 Luke T. Shumaker +# SPDX-Licence-Identifier: AGPL-3.0-or-later + +import enum +import os.path +import re +from typing import Callable + +# Parse *.txt ################################################################## + + +class Atom(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 4 + u64 = 8 + + @property + def name(self) -> str: + return str(self.value) + + @property + def static_size(self) -> int: + return self.value + + +# `msgid/structname = "member1 member2..."` +# `structname = "member1 member2..."` +# `structname += "member1 member2..."` +class Struct: + msgid: int | None = None + msgver: set[str] + name: str + members: list["Member"] + + def __init__(self) -> None: + self.msgver = set() + + @property + def static_size(self) -> int | None: + size = 0 + for member in self.members: + msize = member.static_size + if msize is None: + return None + size += msize + return size + + +# `cnt*(name[typ])` +# the `cnt*(...)` wrapper is optional +class Member: + cnt: str | None = None + name: str + typ: Atom | Struct + ver: set[str] + + @property + def static_size(self) -> int | None: + if self.cnt: + return None + return self.typ.static_size + + +re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" +re_memberspec = ( + f"(?:(?P{re_membername})\\*\\()?(?P{re_membername})\\[(?P.*)\\]\\)?" +) + + +def parse_members( + ver: str, + env: dict[str, Atom | Struct], existing: list[Member], specs: str +) -> list[Member]: + ret = existing + for spec in specs.split(): + m = re.fullmatch(re_memberspec, spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + + member = Member() + member.ver = {ver} + + member.name = m.group("name") + if any(x.name == member.name for x in ret): + raise ValueError(f"duplicate member name {repr(member.name)}") + + if m.group("typ") not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + member.typ = env[m.group("typ")] + + if cnt := m.group("cnt"): + if len(ret) == 0 or ret[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(ret[-1].typ, Atom): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + member.cnt = cnt + + ret += [member] + return ret + + +re_version = r'version\s+"(?P[^"]+)"' +re_import = r'from\s+(?P\S+)\s+import\s+(?P\S+(?:\s*,\s*\S+)*)\s*' +re_structspec = ( + r'(?:(?P[0-9]+)/)?(?P\S+)\s*(?P\+?=)\s*"(?P[^"]*)"' +) +re_structspec_cont = r'"(?P[^"]*)"' + + +def parse_file(filename: str, get_include: Callable[[str], tuple[str, list[Struct]]]) -> tuple[str, list[Struct]]: + version: str | None = None + env: dict[str, Atom | Struct] = { + "1": Atom.u8, + "2": Atom.u16, + "4": Atom.u32, + "8": Atom.u64, + } + with open(filename, "r") as fh: + prev: Struct | None = None + for line in fh: + line = line.split("#", 1)[0].strip() + if not line: + continue + if m := re.fullmatch(re_version, line): + if version: + raise SyntaxError("must have exactly 1 version line") + version = m.group("version") + elif m := re.fullmatch(re_import, line): + if not version: + raise SyntaxError("must have exactly 1 version line") + other_version, other_structs = get_include(m.group("file")) + for symname in m.group("syms").split(sep=","): + symname = symname.strip() + for struct in other_structs: + if struct.name == symname or symname == '*': + if struct.msgid: + struct.msgver.add(version) + for member in struct.members: + if other_version in member.ver: + member.ver.add(version) + env[struct.name] = struct + elif m := re.fullmatch(re_structspec, line): + if not version: + raise SyntaxError("must have exactly 1 version line") + if m.group("op") == "+=" and m.group("msgid"): + raise SyntaxError("cannot += to a message that is not yet defined") + match m.group("op"): + case "=": + struct = Struct() + if m.group("msgid"): + struct.msgid = int(m.group("msgid")) + struct.msgver.add(version) + struct.name = m.group("name") + struct.members = parse_members(version, env, [], m.group("members")) + env[struct.name] = struct + prev = struct + case "+=": + if m.group("name") not in env: + raise NameError(f"Unknown type {repr(m.group('name'))}") + _struct = env[m.group("name")] + if not isinstance(_struct, Struct): + raise NameError( + f"Type {repr(m.group('name'))} is not a struct" + ) + struct = _struct + struct.members = parse_members(version, + env, struct.members, m.group("members") + ) + prev = struct + elif m := re.fullmatch(re_structspec_cont, line): + if not prev: + raise SyntaxError("continuation line must come after a struct line") + assert(version) + prev.members = parse_members(version, env, prev.members, m.group("members")) + else: + raise SyntaxError(f"invalid line {repr(line)}") + if not version: + raise SyntaxError("must have exactly 1 version line") + structs = [x for x in env.values() if isinstance(x, Struct)] + return version, structs + +# Generate C ################################################################### + + +def c_typename(idprefix: str, typ: Atom | Struct) -> str: + match typ: + case Atom(): + return f"uint{typ.value*8}_t" + case Struct(): + if typ.msgid is not None: + return f"struct {idprefix}msg_{typ.name}" + return f"struct {idprefix}{typ.name}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + +def c_vercomment(versions: set[str]) -> str|None: + if "9P2000" in versions: + return None + return "/* "+(", ".join(sorted(versions)))+" */" + +def c_ver(idprefix: str, ver: str) -> str: + return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" + +def gen_h(idprefix: str, versions: set[str], structs: list[Struct]) -> str: + guard = "_LIB9P__TYPES_H_" + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#ifndef {guard} +#define {guard} + +#include +""" + + ret += f""" +/* versions *******************************************************************/ + +enum {idprefix}version {{ + {idprefix.upper()}VER_UNINITIALIZED = 0, +""" + verwidth = max(len(v) for v in versions) + for ver in sorted(versions): + ret += f"\t{c_ver(idprefix, ver)}," + ret += (" "*(verwidth-len(ver))) + ' /* "' + ver + '" */\n' + ret += f"\t{idprefix.upper()}VER_NUM,\n" + ret += "};\n" + + ret += """ +/* non-message structs ********************************************************/ +""" + for struct in structs: + if struct.msgid is not None: + continue + + all_the_same = len(struct.members) == 0 or all(m.ver == struct.members[0].ver for m in struct.members) + typewidth = max(len(c_typename(idprefix, m.typ)) for m in struct.members) + if not all_the_same: + namewidth = max(len(m.name) for m in struct.members) + + ret += "\n" + ret += c_typename(idprefix, struct) + " {\n" + for member in struct.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" + if (not all_the_same) and (comment := c_vercomment(member.ver)): + ret += (" "*(namewidth-len(member.name))) + " " + comment + ret += "\n" + ret += "};\n" + + ret += """ +/* messages *******************************************************************/ + +""" + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.name) for msg in structs if msg.msgid is not None) + for msg in structs: + if msg.msgid is None: + continue + ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid}," + if (comment := c_vercomment(msg.msgver)): + ret += " " + comment + ret += "\n" + ret += "};\n" + + ret += "\n" + ret += f"#define {idprefix.upper()}TYPECODE_FOR_CTYPE(msg) _Generic((msg)" + for msg in structs: + if msg.msgid is None: + continue + ret += f", \\\n\t\t{c_typename(idprefix, msg)}: {idprefix.upper()}TYP_{msg.name}" + ret += ")\n" + + for msg in structs: + if msg.msgid is None: + continue + + ret += "\n" + if comment := c_vercomment(msg.msgver): + ret += comment + "\n" + ret += c_typename(idprefix, msg) + " {" + if not msg.members: + ret += "};\n" + continue + ret += "\n" + + all_the_same = len(msg.members) == 0 or all(m.ver == msg.members[0].ver for m in msg.members) + typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members) + if not all_the_same: + namewidth = max(len(m.name) for m in msg.members) + + for member in msg.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};" + if (not all_the_same) and (comment := c_vercomment(member.ver)): + ret += (" "*(namewidth-len(member.name))) + " " + comment + ret += "\n" + ret += "};\n" + + + ret += "\n" + ret += f"#endif /* {guard} */\n" + return ret + + +def gen_c(idprefix: str, versions: set[str], structs: list[Struct]) -> str: + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#include + +#include + +#include "internal.h" +""" + + def used(arg: str) -> str: + return arg + + def unused(arg: str) -> str: + return f"UNUSED({arg})" + + # checksize_* ############################################################## + ret += """ +/* checksize_* (internals of unmarshal_size()) ********************************/ + +static inline bool _checksize_net(struct _checksize_ctx *ctx, uint32_t n) { + if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) + /* If needed-net-size overflowed uint32_t, then + * there's no way that actual-net-size will live up to + * that. */ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + if (ctx->net_offset > ctx->net_size) + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static inline bool _checksize_host(struct _checksize_ctx *ctx, size_t n) { + if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) + /* If needed-host-size overflowed size_t, then there's + * no way that actual-net-size will live up to + * that. */ + return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); + return false; +} + +static inline bool _checksize_list(struct _checksize_ctx *ctx, + size_t cnt, _checksize_fn_t item_fn, size_t item_host_size) { + for (size_t i = 0; i < cnt; i++) + if (_checksize_host(ctx, item_host_size) || item_fn(ctx)) + return true; + return false; +} + +#define checksize_1(ctx) _checksize_net(ctx, 1) +#define checksize_2(ctx) _checksize_net(ctx, 2) +#define checksize_4(ctx) _checksize_net(ctx, 4) +#define checksize_8(ctx) _checksize_net(ctx, 8) +""" + for struct in structs: + inline = ' inline' if struct.msgid is None else '' + argfn = used if struct.members else unused + ret += "\n" + ret += f"static{inline} bool checksize_{struct.name}(struct _checksize_ctx *{argfn('ctx')}) {{" + if len(struct.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + + match struct.name: + case "d": + # Optimize... maybe the compiler could figure out to do + # this, but let's make it obvious. + ret += "\n" + ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (_checksize_4(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n" + ret += "\treturn checksize_net(ctx, len) || checksize_host(ctx, len);\n" + ret += "}\n" + case "s": + # Add an extra nul-byte on the host, and validate + # UTF-8 (also, similar optimization to "d"). + ret += "\n" + ret += "\tuint32_t base_offset = ctx->net_offset;\n" + ret += "\tif (_checksize_2(ctx))\n" + ret += "\t\treturn true;\n" + ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n" + ret += "\tif (checksize_net(ctx, len) || checksize_host(ctx, ((size_t)len)+1))\n" + ret += "\t\treturn true;\n" + ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n" + ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n' + ret += "\treturn false;\n" + ret += "}\n" + case _: + struct_versions = struct.members[0].ver + + prefix = prefix0 + prev_size: int | None = None + for member in struct.members: + ret += f"\n{prefix}" + if member.ver != struct_versions: + ret += "( ( " + (" || ".join(f"(ctx->ctx->version=={c_ver(idprefix, v)})" for v in sorted(member.ver))) + " ) && " + if member.cnt is not None: + assert prev_size + ret += f"_checksize_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), checksize_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))" + else: + ret += f"checksize_{member.typ.name}(ctx)" + if member.ver != struct_versions: + ret += " )" + prefix = prefix1 + prev_size = member.static_size + ret += ";\n}\n" + + # unmarshal_* ############################################################## + ret += """ +/* unmarshal_* ****************************************************************/ + +static inline vold unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { + *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 1; +} + +static inline vold unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { + *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 2; +} + +static inline vold unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { + *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 4; +} + +static inline vold unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { + *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]); + ctx->net_offset += 8; +} +""" + for struct in structs: + inline = ' inline' if struct.msgid is None else '' + argfn = used if struct.members else unused + ret += "\n" + ret += f"static{inline} void unmarshal_{struct.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('out')}) {{\n" + ret += "\tmemset(out, 0, sizeof(*out));\n" + + if struct.members: + struct_versions = struct.members[0].ver + for member in struct.members: + ret += "\t" + prefix = "\t" + if member.ver != struct_versions: + ret += "if ( " + (" || ".join(f"(ctx->ctx->version=={c_ver(idprefix, v)})" for v in sorted(member.ver))) + " ) " + prefix = "\t\t" + if member.cnt: + if member.ver != struct_versions: + ret += "{\n" + ret += f"{prefix}out->{member.name} = ctx->host_extra;\n" + ret += f"{prefix}ctx->host_extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n" + ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + if member.ver != struct_versions: + ret += "\t}\n" + else: + ret += f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" + ret += "}\n" + + # marshal_* ################################################################ + ret += """ +/* marshal_* ******************************************************************/ + +static inline bool _marshal_too_large(struct _marshal_ctx *ctx) { + lib9p_errorf(ctx->ctx, "%s too large to marshal into %s limit (limit=%"PRIu32")", + (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", + ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), + ctx->ctx->max_msg_size)); + return true; +} + +static inline bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { + if (ctx->net_offset + 1 > ctx->max_msg_size) + return _marshal_too_large(ctx); + out_net_bytes[ctx->net_offset] = *val; + ctx->net_offset += 1; + return false; +} + +static inline bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { + if (ctx->net_offset + 2 > ctx->max_msg_size) + return _marshal_too_large(ctx); + encode_u16le(*val, &out_net_bytes[ctx->net_offset]); + ctx->net_offset += 2; + return false; +} + +static inline bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { + if (ctx->net_offset + 4 > ctx->max_msg_size) + return true; + encode_u32le(*val, &out_net_bytes[ctx->net_offset]); + ctx->net_offset += 4; + return false; +} + +static inline bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { + if (ctx->net_offset + 8 > ctx->max_msg_size) + return true; + encode_u64le(*val, &out_net_bytes[ctx->net_offset]); + ctx->net_offset += 8; + return false; +} +""" + for struct in structs: + inline = ' inline' if struct.msgid is None else '' + argfn = used if struct.members else unused + ret += "\n" + ret += f"static{inline} bool marshal_{struct.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, struct)} *{argfn('val')}) {{" + if len(struct.members) == 0: + ret += "\n\treturn false;\n" + ret += "}\n" + continue + + prefix0 = "\treturn " + prefix1 = "\t || " + prefix2 = "\t " + + prefix = prefix0 + for member in struct.members: + if member.cnt: + ret += f"\n{prefix }({{" + ret += f"\n{prefix2}\tbool err = false;" + ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)" + ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]));" + ret += f"\n{prefix2}\terr;" + ret += f"\n{prefix2}}})" + else: + ret += f"\n{prefix}marshal_{member.typ.name}(ctx, &val->{member.name})" + prefix = prefix1 + ret += ";\n}\n" + + # vtables ################################################################## + def msg_entry(msg: Struct) -> str: + ret = "" + ret += f"\t\t[{idprefix.upper()}TYP_{msg.name}] = {{\n" + ret += f"\t\t\tr.unmarshal_basesize = sizeof({c_typename(idprefix, msg)}),\n" + ret += f"\t\t\tr.unmarshal_extrasize = checksize_{msg.name},\n" + ret += f"\t\t\tr.unmarshal = unmarshal_{msg.name},\n" + ret += f"\t\t\tr.marshal = (_marshal_fn_t)marshal_{msg.name},\n" + ret += "\t\t}" + return ret + ret += f""" +/* vtables ********************************************************************/ + +struct _vtable_version _{idprefix}vtables[LIB9P_VER_NUM] = {{ + [{idprefix.upper()}VER_UNINITIALIZED] = {{ +{msg_entry(next(msg for msg in structs if msg.name == 'Tversion'))}, +{msg_entry(next(msg for msg in structs if msg.name == 'Rversion'))}, + }}, +""" + for ver in sorted(versions): + ret += f"\t[{c_ver(idprefix, ver)}] = {{\n" + for msg in structs: + if ver not in msg.msgver: + continue + ret += msg_entry(msg) + ",\n" + ret += "\t},\n" + ret += "};\n" + + ############################################################################ + return ret + + +################################################################################ + +class Parser: + cache: dict[str, tuple[str, list[Struct]]] = {} + + def parse_file(self, filename: str) -> tuple[str, list[Struct]]: + if filename not in self.cache: + self.cache[filename] = parse_file(filename, self.parse_file) + return self.cache[filename] + + def all(self) -> tuple[set[str], list[Struct]]: + ret_versions: set[str] = set() + ret_structs: dict[str, Struct] = {} + for (version, structs) in self.cache.values(): + if version in ret_versions: + raise ValueError(f"duplicate protocol version {repr(version)}") + ret_versions.add(version) + for struct in structs: + if struct.name in ret_structs: + if struct != ret_structs[struct.name]: + raise ValueError(f"duplicate struct name {repr(struct.name)}") + else: + ret_structs[struct.name] = struct + return ret_versions, list(ret_structs.values()) + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 2: + raise ValueError("requires at least 1 .txt filename") + parser = Parser() + for txtname in sys.argv[1:]: + parser.parse_file(txtname) + versions, structs = parser.all() + with open("include/lib9p/_types.h", "w") as fh: + fh.write(gen_h("lib9p_", versions, structs)) + with open("types.c", "w") as fh: + fh.write(gen_c("lib9p_", versions, structs)) -- cgit v1.2.3-2-g168b