diff options
Diffstat (limited to '9p/generate')
-rwxr-xr-x | 9p/generate | 541 |
1 files changed, 541 insertions, 0 deletions
diff --git a/9p/generate b/9p/generate new file mode 100755 index 0000000..6456609 --- /dev/null +++ b/9p/generate @@ -0,0 +1,541 @@ +#!/usr/bin/env python +# 9p/generate - Generate C marshalers/unmarshalers for a .txt file +# defining a 9P protocol variant. +# +# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-Licence-Identifier: AGPL-3.0-or-later + +import enum +import re + +PROGRAM = "./9p/generate" + +# Parse the *.txt ############################################################## + + +class Atom(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 4 + u64 = 8 + + @property + def name(self) -> str: + return str(self.value) + + @property + def static_size(self) -> int: + return self.value + + +# `msgid/structname = "member1 member2..."` +# `structname = "member1 member2..."` +# `structname += "member1 member2..."` +class Struct: + msgid: int | None = None + name: str + members: list["Member"] + + @property + def static_size(self) -> int | None: + size = 0 + for member in self.members: + msize = member.static_size + if msize is None: + return None + size += msize + return size + + +# `cnt*(name[typ])` +# the `cnt*(...)` wrapper is optional +class Member: + cnt: str | None = None + name: str + typ: Atom | Struct + + @property + def static_size(self) -> int | None: + if self.cnt: + return None + return self.typ.static_size + + +re_membername = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" +re_memberspec = ( + f"(?:(?P<cnt>{re_membername})\\*\\()?(?P<name>{re_membername})\\[(?P<typ>.*)\\]\\)?" +) + + +def parse_members( + env: dict[str, Atom | Struct], existing: list[Member], specs: str +) -> list[Member]: + ret = existing + for spec in specs.split(): + m = re.fullmatch(re_memberspec, spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + + member = Member() + + member.name = m.group("name") + if any(x.name == member.name for x in ret): + raise ValueError(f"duplicate member name {repr(member.name)}") + + if m.group("typ") not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + member.typ = env[m.group("typ")] + + if cnt := m.group("cnt"): + if len(ret) == 0 or ret[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(ret[-1].typ, Atom): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + member.cnt = cnt + + ret += [member] + return ret + + +re_version = r'version\s+"(?P<version>[^"]+)"' +re_structspec = ( + r'(?:(?P<msgid>[0-9]+)/)?(?P<name>\S+)\s*(?P<op>\+?=)\s*"(?P<members>[^"]*)"' +) +re_structspec_cont = r'"(?P<members>[^"]*)"' + + +def parse_file(filename: str) -> tuple[str, list[Struct]]: + version: str | None = None + env: dict[str, Atom | Struct] = { + "1": Atom.u8, + "2": Atom.u16, + "4": Atom.u32, + "8": Atom.u64, + } + with open(filename, "r") as fh: + prev: Struct | None = None + for line in fh: + line = line.split("#", 1)[0].strip() + if not line: + continue + if m := re.fullmatch(re_version, line): + if version: + raise SyntaxError("must have exactly 1 version line") + version = m.group("version") + elif m := re.fullmatch(re_structspec, line): + if m.group("op") == "+=" and m.group("msgid"): + raise SyntaxError("cannot += to a message that is not yet defined") + match m.group("op"): + case "=": + struct = Struct() + if m.group("msgid"): + struct.msgid = int(m.group("msgid")) + struct.name = m.group("name") + struct.members = parse_members(env, [], m.group("members")) + env[struct.name] = struct + prev = struct + case "+=": + if m.group("name") not in env: + raise NameError(f"Unknown type {repr(m.group('name'))}") + _struct = env[m.group("name")] + if not isinstance(_struct, Struct): + raise NameError( + f"Type {repr(m.group('name'))} is not a struct" + ) + struct = _struct + struct.members = parse_members( + env, struct.members, m.group("members") + ) + prev = struct + elif m := re.fullmatch(re_structspec_cont, line): + if not prev: + raise SyntaxError("continuation line must come after a struct line") + prev.members = parse_members(env, prev.members, m.group("members")) + else: + raise SyntaxError(f"invalid line {repr(line)}") + if not version: + raise SyntaxError("must have exactly 1 version line") + structs = [x for x in env.values() if isinstance(x, Struct)] + return version, structs + + +# Generate C ################################################################### + + +def c_typename(idprefix: str, typ: Atom | Struct) -> str: + match typ: + case Atom(): + return f"uint{typ.value*8}_t" + case Struct(): + if typ.msgid is not None: + return f"struct {idprefix}msg_{typ.name}" + return f"struct {idprefix}{typ.name}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def gen_h(txtname: str, idprefix: str, structs: list[Struct]) -> str: + guard = ( + "_" + + txtname.replace(".txt", ".h").upper().replace("/", "_").replace(".", "_") + + "_" + ) + ret = f"""/* Generated by `{PROGRAM} {txtname}`. DO NOT EDIT! */ + +#ifndef {guard} +#define {guard} + +#define {idprefix.upper()}MIN_MSGLEN 7 +""" + ret += """ +/* non-message structs ********************************************************/ + +""" + for struct in structs: + if struct.msgid is not None: + continue + ret += c_typename(idprefix, struct) + " {\n" + typewidth = max( + len(c_typename(idprefix, member.typ)) for member in struct.members + ) + for member in struct.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += "};\n" + + ret += """ +/* message types **************************************************************/ + +""" + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.name) for msg in structs if msg.msgid is not None) + for msg in structs: + if msg.msgid is None: + continue + ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" + ret += "};\n" + + ret += """ +/* message structs ************************************************************/ + +""" + for msg in structs: + if msg.msgid is None: + continue + if not msg.members: + ret += c_typename(idprefix, msg) + " {};\n" + continue + ret += c_typename(idprefix, msg) + " {\n" + typewidth = max(len(c_typename(idprefix, member.typ)) for member in msg.members) + for member in msg.members: + ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += "};\n" + + ret += f""" +/* functions ******************************************************************/ + +/** + * @param net_bytes : the complete request, starting with the "size[4]" + * @param out_tag : the message-ID tag + * @param out_body : a pointer that get set to the parsed body, whose + * type is known by the return value, will need to + * be free()d + * @return -{idprefix.upper()}E{{error}} on error, {idprefix.upper()}TYP_{{type}} on success + */ +int {idprefix}unmarshal_msg(uint8_t *net_bytes, uint16_t *out_tag, void **out_body); + +/** + * @param uint16_t in_msgid : the message-ID tag + * @param struct {idprefix}msg_{{type}} in_msg : the message to encode + * @param uint8_t *out_buf : the buffer to encode to + * @return uint32_t : the encoded length + */ +#define {idprefix}marshal_msg(in_msgid, in_msg, out_buf) _Generic((in_msg)""" + for msg in structs: + if msg.msgid is None: + continue + ret += f", \\\n\t\t{c_typename(idprefix, msg)}: _{idprefix}marshal_{msg.name}(in_msgid, in_msg, out_buf)" + ret += "\\\n\t)(in_msg)\n" + for msg in structs: + if msg.msgid is None: + continue + ret += f"uint32_t _{idprefix}marshal_{msg.name}(uint16_t in_msgid, {c_typename(idprefix, msg)} in_msg, uint8_t *out_buf);\n" + + ret += "\n" + ret += f"#endif /* {guard} */\n" + return ret + + +def gen_c(txtname: str, idprefix: str, structs: list[Struct]) -> str: + ret = f"""/* Generated by `{PROGRAM} {txtname}`. DO NOT EDIT! */ + +#include <stdbool.h> +#include <stdint.h> +#include <stdlib.h> /* for malloc() */ + +#include "{txtname.replace('.txt', '.h')}" +""" + + # basic utilities ########################################################## + ret += """ +/* basic utilities ************************************************************/ + +#define UNUSED(name) /* name __attribute__ ((unused)) */ + +static inline uint8_t decode_u8le(uint8_t *in) { + return in[0]; +} +static inline uint16_t decode_u16le(uint8_t *in) { + return (((uint16_t)(in[0])) << 0) + | (((uint16_t)(in[1])) << 8) + ; +} +static inline uint32_t decode_u32le(uint8_t *in) { + return (((uint32_t)(in[0])) << 0) + | (((uint32_t)(in[1])) << 8) + | (((uint32_t)(in[2])) << 16) + | (((uint32_t)(in[3])) << 24) + ; +} +static inline uint64_t decode_u64le(uint8_t *in) { + return (((uint64_t)(in[0])) << 0) + | (((uint64_t)(in[1])) << 8) + | (((uint64_t)(in[2])) << 16) + | (((uint64_t)(in[3])) << 24) + | (((uint64_t)(in[4])) << 32) + | (((uint64_t)(in[5])) << 40) + | (((uint64_t)(in[6])) << 48) + | (((uint64_t)(in[7])) << 56) + ; +} + +static inline void encode_u8le(uint8_t in, uint8_t *out) { + out[0] = in; +} +static inline void encode_u16le(uint16_t in, uint8_t *out) { + out[0] = (uint8_t)((in >> 0) & 0xFF); + out[1] = (uint8_t)((in >> 8) & 0xFF); +} +static inline void encode_u32le(uint32_t in, uint8_t *out) { + out[0] = (uint8_t)((in >> 0) & 0xFF); + out[1] = (uint8_t)((in >> 8) & 0xFF); + out[2] = (uint8_t)((in >> 16) & 0xFF); + out[3] = (uint8_t)((in >> 24) & 0xFF); +} +static inline void encode_u64le(uint64_t in, uint8_t *out) { + out[0] = (uint8_t)((in >> 0) & 0xFF); + out[1] = (uint8_t)((in >> 8) & 0xFF); + out[2] = (uint8_t)((in >> 16) & 0xFF); + out[3] = (uint8_t)((in >> 24) & 0xFF); + out[4] = (uint8_t)((in >> 32) & 0xFF); + out[5] = (uint8_t)((in >> 40) & 0xFF); + out[6] = (uint8_t)((in >> 48) & 0xFF); + out[7] = (uint8_t)((in >> 56) & 0xFF); +} +""" + + def used(arg: str) -> str: + return arg + + def unused(arg: str) -> str: + return f"UNUSED({arg})" + + # checksize_* ############################################################## + ret += """ +/* checksize_* ****************************************************************/ + +typedef bool (*_checksize_fn_t)(uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra); +static inline bool _checksize_list(size_t cnt, _checksize_fn_t fn, size_t host_size, + uint32_t net_len, uint8_t *net_bytes, uint32_t *mut_net_offset, size_t *mut_host_extra) { + for (size_t i = 0; i < cnt; i++) + if (__builtin_add_overflow(*mut_host_extra, host_size, mut_host_extra) + || fn(net_len, net_bytes, mut_net_offset, mut_host_extra)) + return true; + return false; +} +static inline bool checksize_1(uint32_t net_len, uint8_t *UNUSED(net_bytes), uint32_t *mut_net_offset, size_t *UNUSED(mut_host_extra)) { + return __builtin_add_overflow(*mut_net_offset, 1, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_2(uint32_t net_len, uint8_t *UNUSED(net_bytes), uint32_t *mut_net_offset, size_t *UNUSED(mut_host_extra)) { + return __builtin_add_overflow(*mut_net_offset, 2, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_4(uint32_t net_len, uint8_t *UNUSED(net_bytes), uint32_t *mut_net_offset, size_t *UNUSED(mut_host_extra)) { + return __builtin_add_overflow(*mut_net_offset, 4, mut_net_offset) || net_len < *mut_net_offset; +} +static inline bool checksize_8(uint32_t net_len, uint8_t *UNUSED(net_bytes), uint32_t *mut_net_offset, size_t *UNUSED(mut_host_extra)) { + return __builtin_add_overflow(*mut_net_offset, 8, mut_net_offset) || net_len < *mut_net_offset; +} +""" + for struct in structs: + argfn = used if struct.members else unused + ret += f"static inline bool checksize_{struct.name}(uint32_t {argfn('net_len')}, uint8_t *{argfn('net_bytes')}, uint32_t *{argfn('mut_net_offset')}, size_t *{argfn('mut_host_extra')}) {{\n" + if len(struct.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + prefix0 = "\treturn " + prefix1 = "\t || " + prefix2 = "\t " + prefix = prefix0 + prev_size: int | None = None + for member in struct.members: + if member.cnt is not None: + assert prev_size + ret += f"\n{prefix }_checksize_list(decode_u{prev_size*8}le(&net_bytes[(*mut_net_offset)-{prev_size}]), checksize_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)})," + ret += f"\n{prefix2} net_len, net_bytes, mut_net_offset, mut_host_extra)" + else: + ret += f"\n{prefix}checksize_{member.typ.name}(net_len, net_bytes, mut_net_offset, mut_host_extra)" + prefix = prefix1 + prev_size = member.static_size + if struct.name == "s": + ret += ( + f"\n{prefix}__builtin_add_overflow(*mut_host_extra, 1, mut_host_extra)" + ) + ret += ";\n}\n" + + # unmarshal_* ############################################################## + ret += """ +/* unmarshal_* ****************************************************************/ +/* checksize_XXX() should be called before unmarshal_XXX(). */ + +static inline void unmarshal_1(uint8_t *net_bytes, uint32_t *mut_net_offset, void **UNUSED(mut_host_extra), uint8_t *out) { + *out = decode_u8le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 1; +} +static inline void unmarshal_2(uint8_t *net_bytes, uint32_t *mut_net_offset, void **UNUSED(mut_host_extra), uint16_t *out) { + *out = decode_u16le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 2; +} +static inline void unmarshal_4(uint8_t *net_bytes, uint32_t *mut_net_offset, void **UNUSED(mut_host_extra), uint32_t *out) { + *out = decode_u32le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 4; +} +static inline void unmarshal_8(uint8_t *net_bytes, uint32_t *mut_net_offset, void **UNUSED(mut_host_extra), uint64_t *out) { + *out = decode_u64le(&net_bytes[*mut_net_offset]); + *mut_net_offset += 8; +} +""" + for struct in structs: + argfn = used if struct.members else unused + ret += f"static inline void unmarshal_{struct.name}(uint8_t *{argfn('net_bytes')}, uint32_t *{argfn('mut_net_offset')}, void **{argfn('mut_host_extra')}, {c_typename(idprefix, struct)} *{argfn('out')}) {{" + if len(struct.members) == 0: + ret += "}\n" + continue + ret += "\n" + for member in struct.members: + if member.cnt: + ret += f"\tout->{member.name} = *mut_host_extra;\n" + ret += f"\t*mut_host_extra += sizeof(*out->{member.name}) * out->{member.cnt};\n" + ret += f"\tfor (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n" + ret += f"\t\tunmarshal_{member.typ.name}(net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}[i]));\n" + else: + ret += f"\tunmarshal_{member.typ.name}(net_bytes, mut_net_offset, mut_host_extra, &(out->{member.name}));\n" + ret += "}\n" + + # unmarshal_msg ############################################################ + ret += f""" +/* unmarshal_msg **************************************************************/ + +int {idprefix}unmarshal_msg(uint8_t *net_bytes, uint16_t *out_tag, void **out_body) {{ + uint32_t net_len = decode_u32le(net_bytes); + if (net_len < 7) + return -LINUX_EWRONGSIZE; + uint8_t typ = net_bytes[4]; + *out_tag = decode_u16le(&net_bytes[5]); + + uint32_t net_offset = 7; + size_t host_size; + void *host_extra; + switch (typ) {{ +""" + for msg in structs: + if msg.msgid is None: + continue + ret += f"\tcase {idprefix.upper()}TYP_{msg.name}:\n" + ret += f"\t\thost_size = sizeof({c_typename(idprefix, msg)});\n" + ret += f"\t\tif (checksize_{msg.name}(net_len, net_bytes, &net_offset, &host_size))\n" + ret += "\t\t\treturn -LINUX_EWRONGSIZE;\n" + ret += "\n" + ret += "\t\t*out_body = malloc(host_size);" + ret += "\n" + ret += "\t\tnet_offset = 7;\n" + ret += f"\t\thost_extra = *out_body + sizeof({c_typename(idprefix, msg)});\n" + ret += f"\t\tunmarshal_{msg.name}(net_bytes, &net_offset, &host_extra, *out_body);\n" + ret += "\n" + ret += "\t\tbreak;\n" + ret += """ + default: + return -LINUX_EOPNOTSUPP; + } + return typ; +} +""" + + # marshal_* ################################################################ + ret += """ +/* marshal_* ******************************************************************/ + +static inline void marshal_1(uint8_t val, uint8_t *out_net_bytes, uint32_t *mut_net_offset) { + out_net_bytes[*mut_net_offset] = val; + *mut_net_offset += 1; +} +static inline void marshal_2(uint16_t val, uint8_t *out_net_bytes, uint32_t *mut_net_offset) { + encode_u16le(val, &out_net_bytes[*mut_net_offset]); + *mut_net_offset += 2; +} +static inline void marshal_4(uint32_t val, uint8_t *out_net_bytes, uint32_t *mut_net_offset) { + encode_u32le(val, &out_net_bytes[*mut_net_offset]); + *mut_net_offset += 4; +} +static inline void marshal_8(uint64_t val, uint8_t *out_net_bytes, uint32_t *mut_net_offset) { + encode_u64le(val, &out_net_bytes[*mut_net_offset]); + *mut_net_offset += 8; +} +""" + for struct in structs: + argfn = used if struct.members else unused + ret += f"static inline void marshal_{struct.name}({c_typename(idprefix, struct)} {argfn('val')}, uint8_t *{argfn('out_net_bytes')}, uint32_t *{argfn('mut_net_offset')}) {{" + if len(struct.members) == 0: + ret += "}\n" + continue + ret += "\n" + for member in struct.members: + if member.cnt: + ret += f"\tfor (typeof(val.{member.cnt}) i = 0; i < val.{member.cnt}; i++)\n" + ret += f"\t\tmarshal_{member.typ.name}(val.{member.name}[i], out_net_bytes, mut_net_offset);\n" + else: + ret += f"\tmarshal_{member.typ.name}(val.{member.name}, out_net_bytes, mut_net_offset);\n" + ret += "}\n" + + # _marshal_msg_* ########################################################### + ret += """ +/* _marshal_msg_* *************************************************************/ + +""" + for msg in structs: + if msg.msgid is None: + continue + ret += f"uint32_t _{idprefix}marshal_{msg.name}(uint16_t in_msgid, {c_typename(idprefix, msg)} in_msg, uint8_t *out_buf) {{\n" + ret += "\tuint32_t offset = 4;\n" + ret += f"\tmarshal_1({idprefix.upper()}TYP_{msg.name}, out_buf, &offset);\n" + ret += "\tmarshal_2(in_msgid, out_buf, &offset);\n" + ret += f"\tmarshal_{msg.name}(in_msg, out_buf, &offset);\n" + ret += "\tencode_u32le(offset, out_buf);\n" + ret += "\treturn offset;\n" + ret += "}\n" + ret += "\n" + + ############################################################################ + return ret + + +################################################################################ + +if __name__ == "__main__": + import sys + + for txtname in sys.argv[1:]: + version, structs = parse_file(txtname) + with open(txtname.replace(".txt", ".h"), "w") as fh: + fh.write(gen_h(txtname, "p9_", structs)) + with open(txtname.replace(".txt", ".c"), "w") as fh: + fh.write(gen_c(txtname, "p9_", structs)) |