diff options
Diffstat (limited to 'lib9p/idl')
-rw-r--r-- | lib9p/idl/__init__.py | 491 |
1 files changed, 491 insertions, 0 deletions
diff --git a/lib9p/idl/__init__.py b/lib9p/idl/__init__.py new file mode 100644 index 0000000..920d02d --- /dev/null +++ b/lib9p/idl/__init__.py @@ -0,0 +1,491 @@ +# lib9p/idl/__init__.py - A parser for .9p specification files. +# +# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import enum +import os.path +import re +from typing import Callable, Literal, TypeAlias, TypeVar, cast + +__all__ = [ + # entrypoint + "Parser", + # types + "Type", + "Primitive", + "Number", + *["Bitfield", "BitfieldVal"], + *["Struct", "StructMember", "Expr", "ExprOp", "ExprSym", "ExprLit"], + "Message", +] + +# The syntax that this parses is described in `./0000-README.md`. + +# Types ######################################################################## + + +class Primitive(enum.Enum): + u8 = 1 + u16 = 2 + u32 = 4 + u64 = 8 + + @property + def in_versions(self) -> set[str]: + return set() + + @property + def name(self) -> str: + return str(self.value) + + @property + def static_size(self) -> int: + return self.value + + +class Number: + name: str + in_versions: set[str] + + prim: Primitive + + def __init__(self) -> None: + self.in_versions = set() + + @property + def static_size(self) -> int: + return self.prim.static_size + + +class BitfieldVal: + name: str + in_versions: set[str] + + val: str + + def __init__(self) -> None: + self.in_versions = set() + + +class Bitfield: + name: str + in_versions: set[str] + + prim: Primitive + + bits: list[str] # bitnames + names: dict[str, BitfieldVal] # bits *and* aliases + + def __init__(self) -> None: + self.in_versions = set() + self.names = {} + + @property + def static_size(self) -> int: + return self.prim.static_size + + def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool: + """Return whether the given bit is valid in the given protocol + version. + + """ + bitname = self.bits[bit] if isinstance(bit, int) else bit + assert bitname in self.bits + if not bitname: + return False + if bitname.startswith("_"): + return False + if ver and (ver not in self.names[bitname].in_versions): + return False + return True + + +class ExprLit: + val: int + + def __init__(self, val: int) -> None: + self.val = val + + +class ExprSym: + name: str + + def __init__(self, name: str) -> None: + self.name = name + + +class ExprOp: + op: Literal["-", "+"] + + def __init__(self, op: Literal["-", "+"]) -> None: + self.op = op + + +class Expr: + tokens: list[ExprLit | ExprSym | ExprOp] + + def __init__(self) -> None: + self.tokens = [] + + def __bool__(self) -> bool: + return len(self.tokens) > 0 + + +class StructMember: + # from left-to-right when parsing + cnt: str | None = None + name: str + typ: "Type" + max: Expr + val: Expr + + in_versions: set[str] + + @property + def static_size(self) -> int | None: + if self.cnt: + return None + return self.typ.static_size + + +class Struct: + name: str + in_versions: set[str] + + members: list[StructMember] + + def __init__(self) -> None: + self.in_versions = set() + + @property + def static_size(self) -> int | None: + size = 0 + for member in self.members: + msize = member.static_size + if msize is None: + return None + size += msize + return size + + +class Message(Struct): + @property + def msgid(self) -> int: + assert len(self.members) >= 3 + assert self.members[1].name == "typ" + assert self.members[1].static_size == 1 + assert self.members[1].val + assert len(self.members[1].val.tokens) == 1 + assert isinstance(self.members[1].val.tokens[0], ExprLit) + return self.members[1].val.tokens[0].val + + +Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message +# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13 +T = TypeVar("T", Number, Bitfield, Struct, Message) + +# Parse ######################################################################## + +re_priname = "(?:1|2|4|8)" # primitive names +re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names +re_impname = r"(?:\*|" + re_symname + ")" # names we can import +re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be + +re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be + +re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)" + +re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})" +re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)" + +re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?" + + +def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: + spec = spec.strip() + + bit: int | None + val: BitfieldVal + if m := re.fullmatch(re_bitspec_bit, spec): + bit = int(m.group("bit")) + name = m.group("name") + + val = BitfieldVal() + val.name = name + val.val = f"1<<{bit}" + val.in_versions.add(ver) + + if bit < 0 or bit >= len(bf.bits): + raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") + if bf.bits[bit]: + raise ValueError(f"{bf.name}: bit {bit} already assigned") + bf.bits[bit] = val.name + elif m := re.fullmatch(re_bitspec_alias, spec): + name = m.group("name") + valstr = m.group("val") + + val = BitfieldVal() + val.name = name + val.val = valstr + val.in_versions.add(ver) + else: + raise SyntaxError(f"invalid bitfield spec {repr(spec)}") + + if val.name in bf.names: + raise ValueError(f"{bf.name}: name {val.name} already assigned") + bf.names[val.name] = val + + +def parse_expr(expr: str) -> Expr: + assert re.fullmatch(re_expr, expr) + ret = Expr() + for tok in re.split("([-+])", expr): + if tok == "-" or tok == "+": + # I, for the life of me, do not understand why I need this + # cast() to keep mypy happy. + ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] + elif re.fullmatch("[0-9]+", tok): + ret.tokens += [ExprLit(int(tok))] + else: + ret.tokens += [ExprSym(tok)] + return ret + + +def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: + for spec in specs.split(): + m = re.fullmatch(re_memberspec, spec) + if not m: + raise SyntaxError(f"invalid member spec {repr(spec)}") + + member = StructMember() + member.in_versions = {ver} + + member.name = m.group("name") + if any(x.name == member.name for x in struct.members): + raise ValueError(f"duplicate member name {repr(member.name)}") + + if m.group("typ") not in env: + raise NameError(f"Unknown type {repr(m.group(2))}") + member.typ = env[m.group("typ")] + + if cnt := m.group("cnt"): + if len(struct.members) == 0 or struct.members[-1].name != cnt: + raise ValueError(f"list count must be previous item: {repr(cnt)}") + if not isinstance(struct.members[-1].typ, Primitive): + raise ValueError(f"list count must be an integer type: {repr(cnt)}") + member.cnt = cnt + + if maxstr := m.group("max"): + if (not isinstance(member.typ, Primitive)) or member.cnt: + raise ValueError("',max=' may only be specified on a non-repeated atom") + member.max = parse_expr(maxstr) + else: + member.max = Expr() + + if valstr := m.group("val"): + if (not isinstance(member.typ, Primitive)) or member.cnt: + raise ValueError("',val=' may only be specified on a non-repeated atom") + member.val = parse_expr(valstr) + else: + member.val = Expr() + + struct.members += [member] + + +def re_string(grpname: str) -> str: + return f'"(?P<{grpname}>[^"]*)"' + + +re_line_version = f"version\\s+{re_string('version')}" +re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)" +re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" +re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})" +re_line_bitfield_ = ( + f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}" +) +re_line_struct = ( + f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" +) +re_line_msg = ( + f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}" +) +re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg + + +def parse_file( + filename: str, get_include: Callable[[str], tuple[str, list[Type]]] +) -> tuple[str, list[Type]]: + version: str | None = None + env: dict[str, Type] = { + "1": Primitive.u8, + "2": Primitive.u16, + "4": Primitive.u32, + "8": Primitive.u64, + } + + def get_type(name: str, tc: type[T]) -> T: + nonlocal env + if name not in env: + raise NameError(f"Unknown type {repr(name)}") + ret = env[name] + if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): + raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") + return ret + + with open(filename, "r") as fh: + prev: Type | None = None + for line in fh: + line = line.split("#", 1)[0].rstrip() + if not line: + continue + if m := re.fullmatch(re_line_version, line): + if version: + raise SyntaxError("must have exactly 1 version line") + version = m.group("version") + continue + if not version: + raise SyntaxError("must have exactly 1 version line") + + if m := re.fullmatch(re_line_import, line): + other_version, other_typs = get_include(m.group("file")) + for symname in m.group("syms").split(sep=","): + symname = symname.strip() + for typ in other_typs: + if typ.name == symname or symname == "*": + match typ: + case Primitive(): + pass + case Number(): + typ.in_versions.add(version) + case Bitfield(): + typ.in_versions.add(version) + for val in typ.names.values(): + if other_version in val.in_versions: + val.in_versions.add(version) + case Struct(): # and Message() + typ.in_versions.add(version) + for member in typ.members: + if other_version in member.in_versions: + member.in_versions.add(version) + env[typ.name] = typ + elif m := re.fullmatch(re_line_num, line): + num = Number() + num.name = m.group("name") + num.in_versions.add(version) + + prim = env[m.group("prim")] + assert isinstance(prim, Primitive) + num.prim = prim + + env[num.name] = num + prev = num + elif m := re.fullmatch(re_line_bitfield, line): + bf = Bitfield() + bf.name = m.group("name") + bf.in_versions.add(version) + + prim = env[m.group("prim")] + assert isinstance(prim, Primitive) + bf.prim = prim + + bf.bits = (prim.static_size * 8) * [""] + + env[bf.name] = bf + prev = bf + elif m := re.fullmatch(re_line_bitfield_, line): + bf = get_type(m.group("name"), Bitfield) + parse_bitspec(version, bf, m.group("member")) + + prev = bf + elif m := re.fullmatch(re_line_struct, line): + match m.group("op"): + case "=": + struct = Struct() + struct.name = m.group("name") + struct.in_versions.add(version) + struct.members = [] + parse_members(version, env, struct, m.group("members")) + + env[struct.name] = struct + prev = struct + case "+=": + struct = get_type(m.group("name"), Struct) + parse_members(version, env, struct, m.group("members")) + + prev = struct + elif m := re.fullmatch(re_line_msg, line): + match m.group("op"): + case "=": + msg = Message() + msg.name = m.group("name") + msg.in_versions.add(version) + msg.members = [] + parse_members(version, env, msg, m.group("members")) + + env[msg.name] = msg + prev = msg + case "+=": + msg = get_type(m.group("name"), Message) + parse_members(version, env, msg, m.group("members")) + + prev = msg + elif m := re.fullmatch(re_line_cont, line): + match prev: + case Bitfield(): + parse_bitspec(version, prev, m.group("specs")) + case Struct(): # and Message() + parse_members(version, env, prev, m.group("specs")) + case _: + raise SyntaxError( + "continuation line must come after a bitfield, struct, or msg line" + ) + else: + raise SyntaxError(f"invalid line {repr(line)}") + if not version: + raise SyntaxError("must have exactly 1 version line") + + typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] + + for typ in [typ for typ in typs if isinstance(typ, Struct)]: + valid_syms = ["end", *["&" + m.name for m in typ.members]] + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, ExprSym) and tok.name not in valid_syms: + raise ValueError( + f"{typ.name}.{member.name}: invalid sym: {tok.name}" + ) + + return version, typs + + +# Filesystem ################################################################### + + +class Parser: + cache: dict[str, tuple[str, list[Type]]] = {} + + def parse_file(self, filename: str) -> tuple[str, list[Type]]: + filename = os.path.normpath(filename) + if filename not in self.cache: + + def get_include(other_filename: str) -> tuple[str, list[Type]]: + return self.parse_file(os.path.join(filename, "..", other_filename)) + + self.cache[filename] = parse_file(filename, get_include) + return self.cache[filename] + + def all(self) -> tuple[set[str], list[Type]]: + ret_versions: set[str] = set() + ret_typs: dict[str, Type] = {} + for version, typs in self.cache.values(): + if version in ret_versions: + raise ValueError(f"duplicate protocol version {repr(version)}") + ret_versions.add(version) + for typ in typs: + if typ.name in ret_typs: + if typ != ret_typs[typ.name]: + raise ValueError(f"duplicate type name {repr(typ.name)}") + else: + ret_typs[typ.name] = typ + return ret_versions, list(ret_typs.values()) |