# lib9p/idl/__init__.py - A parser for .9p specification files. # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import enum import os.path import re from typing import Callable, Literal, TypeVar, cast __all__ = [ # entrypoint "Parser", # types "Type", "Primitive", "Number", *["Bitfield", "BitfieldVal"], *["Struct", "StructMember", "Expr", "ExprOp", "ExprSym", "ExprLit"], "Message", ] # The syntax that this parses is described in `./0000-README.md`. # Types ######################################################################## class Primitive(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def in_versions(self) -> set[str]: return set() @property def name(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value def min_size(self, version: str) -> int: return self.value def max_size(self, version: str) -> int: return self.value class Number: name: str in_versions: set[str] prim: Primitive vals: dict[str, str] def __init__(self) -> None: self.in_versions = set() self.vals = {} @property def static_size(self) -> int: return self.prim.static_size def min_size(self, version: str) -> int: return self.static_size def max_size(self, version: str) -> int: return self.static_size class BitfieldVal: name: str in_versions: set[str] val: str def __init__(self) -> None: self.in_versions = set() class Bitfield: name: str in_versions: set[str] prim: Primitive bits: list[str] # bitnames names: dict[str, BitfieldVal] # bits *and* aliases def __init__(self) -> None: self.in_versions = set() self.names = {} @property def static_size(self) -> int: return self.prim.static_size def min_size(self, version: str) -> int: return self.static_size def max_size(self, version: str) -> int: return self.static_size def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool: """Return whether the given bit is valid in the given protocol version. """ bitname = self.bits[bit] if isinstance(bit, int) else bit assert bitname in self.bits if not bitname: return False if bitname.startswith("_"): return False if ver and (ver not in self.names[bitname].in_versions): return False return True class ExprLit: val: int def __init__(self, val: int) -> None: self.val = val class ExprSym: name: str def __init__(self, name: str) -> None: self.name = name class ExprOp: op: Literal["-", "+"] def __init__(self, op: Literal["-", "+"]) -> None: self.op = op class Expr: tokens: list[ExprLit | ExprSym | ExprOp] def __init__(self) -> None: self.tokens = [] def __bool__(self) -> bool: return len(self.tokens) > 0 class StructMember: # from left-to-right when parsing cnt: "StructMember | None" = None name: str typ: "Type" max: Expr val: Expr in_versions: set[str] @property def min_cnt(self) -> int: assert self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( f"list count must be an integer type: {repr(self.cnt.name)}" ) if self.cnt.val: # TODO: allow this? raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}") return 0 @property def max_cnt(self) -> int: assert self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( f"list count must be an integer type: {repr(self.cnt.name)}" ) if self.cnt.val: # TODO: allow this? raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}") if self.cnt.max: # TODO: be more flexible? if len(self.cnt.max.tokens) != 1: raise ValueError( f"list count ,max= may only have 1 token: {repr(self.cnt.name)}" ) match tok := self.cnt.max.tokens[0]: case ExprLit(): return tok.val case ExprSym(name="s32_max"): return (1 << 31) - 1 case ExprSym(name="s64_max"): return (1 << 63) - 1 case _: raise ValueError( f'list count ,max= only allows literal, "s32_max", and "s64_max" tokens: {repr(self.cnt.name)}' ) return (1 << (self.cnt.typ.value * 8)) - 1 @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size def min_size(self, version: str) -> int: cnt = self.min_cnt if self.cnt else 1 return cnt * self.typ.min_size(version) def max_size(self, version: str) -> int: cnt = self.max_cnt if self.cnt else 1 return cnt * self.typ.max_size(version) class Struct: name: str in_versions: set[str] members: list[StructMember] def __init__(self) -> None: self.in_versions = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: if member.in_versions < self.in_versions: return None msize = member.static_size if msize is None: return None size += msize return size def min_size(self, version: str) -> int: return sum( member.min_size(version) for member in self.members if (version in member.in_versions) ) def max_size(self, version: str) -> int: return sum( member.max_size(version) for member in self.members if (version in member.in_versions) ) class Message(Struct): @property def msgid(self) -> int: assert len(self.members) >= 3 assert self.members[1].name == "typ" assert self.members[1].static_size == 1 assert self.members[1].val assert len(self.members[1].val.tokens) == 1 assert isinstance(self.members[1].val.tokens[0], ExprLit) return self.members[1].val.tokens[0].val type Type = Primitive | Number | Bitfield | Struct | Message T = TypeVar("T", Number, Bitfield, Struct, Message) # Parse ######################################################################## re_priname = "(?:1|2|4|8)" # primitive names re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names re_impname = r"(?:\*|" + re_symname + ")" # names we can import re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)" re_numspec = f"(?P{re_symname})\\s*=\\s*(?P\\S+)" re_bitspec_bit = f"(?P[0-9]+)\\s*=\\s*(?P{re_symname})" re_bitspec_alias = f"(?P{re_symname})\\s*=\\s*(?P\\S+)" re_memberspec = f"(?:(?P{re_symname})\\*\\()?(?P{re_symname})\\[(?P{re_memtype})(?:,max=(?P{re_expr})|,val=(?P{re_expr}))*\\]\\)?" def parse_numspec(ver: str, n: Number, spec: str) -> None: spec = spec.strip() if m := re.fullmatch(re_numspec, spec): name = m.group("name") val = m.group("val") if name in n.vals: raise ValueError(f"{n.name}: name {repr(name)} already assigned") n.vals[name] = val else: raise SyntaxError(f"invalid num spec {repr(spec)}") def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: spec = spec.strip() bit: int | None val: BitfieldVal if m := re.fullmatch(re_bitspec_bit, spec): bit = int(m.group("bit")) name = m.group("name") val = BitfieldVal() val.name = name val.val = f"1<<{bit}" val.in_versions.add(ver) if bit < 0 or bit >= len(bf.bits): raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") if bf.bits[bit]: raise ValueError(f"{bf.name}: bit {bit} already assigned") bf.bits[bit] = val.name elif m := re.fullmatch(re_bitspec_alias, spec): name = m.group("name") valstr = m.group("val") val = BitfieldVal() val.name = name val.val = valstr val.in_versions.add(ver) else: raise SyntaxError(f"invalid bitfield spec {repr(spec)}") if val.name in bf.names: raise ValueError(f"{bf.name}: name {val.name} already assigned") bf.names[val.name] = val def parse_expr(expr: str) -> Expr: assert re.fullmatch(re_expr, expr) ret = Expr() for tok in re.split("([-+])", expr): if tok == "-" or tok == "+": # I, for the life of me, do not understand why I need this # cast() to keep mypy happy. ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] elif re.fullmatch("[0-9]+", tok): ret.tokens += [ExprLit(int(tok))] else: ret.tokens += [ExprSym(tok)] return ret def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {repr(spec)}") member = StructMember() member.in_versions = {ver} member.name = m.group("name") if any(x.name == member.name for x in struct.members): raise ValueError(f"duplicate member name {repr(member.name)}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group('typ'))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if len(struct.members) == 0 or struct.members[-1].name != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") cnt_mem = struct.members[-1] member.cnt = cnt_mem _ = member.max_cnt # force validation if maxstr := m.group("max"): if (not isinstance(member.typ, Primitive)) or member.cnt: raise ValueError("',max=' may only be specified on a non-repeated atom") member.max = parse_expr(maxstr) else: member.max = Expr() if valstr := m.group("val"): if (not isinstance(member.typ, Primitive)) or member.cnt: raise ValueError("',val=' may only be specified on a non-repeated atom") member.val = parse_expr(valstr) else: member.val = Expr() struct.members += [member] def re_string(grpname: str) -> str: return f'"(?P<{grpname}>[^"]*)"' re_line_version = f"version\\s+{re_string('version')}" re_line_import = f"from\\s+(?P\\S+)\\s+import\\s+(?P{re_impname}(?:\\s*,\\s*{re_impname})*)" re_line_num = f"num\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield = f"bitfield\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield_ = ( f"bitfield\\s+(?P{re_symname})\\s*\\+=\\s*{re_string('member')}" ) re_line_struct = ( f"struct\\s+(?P{re_symname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_msg = ( f"msg\\s+(?P{re_msgname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg def parse_file( filename: str, get_include: Callable[[str], tuple[str, list[Type]]] ) -> tuple[str, list[Type]]: version: str | None = None env: dict[str, Type] = { "1": Primitive.u8, "2": Primitive.u16, "4": Primitive.u32, "8": Primitive.u64, } def get_type(name: str, tc: type[T]) -> T: nonlocal env if name not in env: raise NameError(f"Unknown type {repr(name)}") ret = env[name] if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") return ret with open(filename, "r") as fh: prev: Type | None = None for lineno, line in enumerate(fh): try: line = line.split("#", 1)[0].rstrip() if not line: continue if m := re.fullmatch(re_line_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") continue if not version: raise SyntaxError("must have exactly 1 version line") if m := re.fullmatch(re_line_import, line): other_version, other_typs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() found = False for typ in other_typs: if typ.name == symname or symname == "*": found = True match typ: case Primitive(): pass case Number(): typ.in_versions.add(version) case Bitfield(): typ.in_versions.add(version) for val in typ.names.values(): if other_version in val.in_versions: val.in_versions.add(version) case Struct(): # and Message() typ.in_versions.add(version) for member in typ.members: if other_version in member.in_versions: member.in_versions.add(version) if typ.name in env and env[typ.name] != typ: raise ValueError( f"duplicate type name {repr(typ.name)}" ) env[typ.name] = typ if symname != "*" and not found: raise ValueError( f"import: {m.group('file')}: no symbol {repr(symname)}" ) elif m := re.fullmatch(re_line_num, line): num = Number() num.name = m.group("name") num.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) num.prim = prim if num.name in env: raise ValueError(f"duplicate type name {repr(num.name)}") env[num.name] = num prev = num elif m := re.fullmatch(re_line_bitfield, line): bf = Bitfield() bf.name = m.group("name") bf.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) bf.prim = prim bf.bits = (prim.static_size * 8) * [""] if bf.name in env: raise ValueError(f"duplicate type name {repr(bf.name)}") env[bf.name] = bf prev = bf elif m := re.fullmatch(re_line_bitfield_, line): bf = get_type(m.group("name"), Bitfield) parse_bitspec(version, bf, m.group("member")) prev = bf elif m := re.fullmatch(re_line_struct, line): match m.group("op"): case "=": struct = Struct() struct.name = m.group("name") struct.in_versions.add(version) struct.members = [] parse_members(version, env, struct, m.group("members")) if struct.name in env: raise ValueError( f"duplicate type name {repr(struct.name)}" ) env[struct.name] = struct prev = struct case "+=": struct = get_type(m.group("name"), Struct) parse_members(version, env, struct, m.group("members")) prev = struct elif m := re.fullmatch(re_line_msg, line): match m.group("op"): case "=": msg = Message() msg.name = m.group("name") msg.in_versions.add(version) msg.members = [] parse_members(version, env, msg, m.group("members")) if msg.name in env: raise ValueError( f"duplicate type name {repr(msg.name)}" ) env[msg.name] = msg prev = msg case "+=": msg = get_type(m.group("name"), Message) parse_members(version, env, msg, m.group("members")) prev = msg elif m := re.fullmatch(re_line_cont, line): match prev: case Bitfield(): parse_bitspec(version, prev, m.group("specs")) case Number(): parse_numspec(version, prev, m.group("specs")) case Struct(): # and Message() parse_members(version, env, prev, m.group("specs")) case _: raise SyntaxError( "continuation line must come after a bitfield, struct, or msg line" ) else: raise SyntaxError("invalid line") except (SyntaxError, NameError, ValueError) as e: e2 = SyntaxError(str(e)) e2.filename = filename e2.lineno = lineno + 1 e2.text = line raise e2 from e if not version: raise SyntaxError("must have exactly 1 version line") typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] for typ in [typ for typ in typs if isinstance(typ, Struct)]: valid_syms = ["end", "s32_max", "s64_max", *["&" + m.name for m in typ.members]] for member in typ.members: if ( not isinstance(member.typ, Primitive) and member.typ.in_versions < member.in_versions ): raise ValueError( f"{typ.name}.{member.name}: type {member.typ.name} does not exist in {member.in_versions.difference(member.typ.in_versions)}" ) for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, ExprSym) and tok.name not in valid_syms: raise ValueError( f"{typ.name}.{member.name}: invalid sym: {tok.name}" ) return version, typs # Filesystem ################################################################### class Parser: cache: dict[str, tuple[str, list[Type]]] = {} def parse_file(self, filename: str) -> tuple[str, list[Type]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[Type]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[Type]]: ret_versions: set[str] = set() ret_typs: dict[str, Type] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for typ in typs: if typ.name in ret_typs: if typ != ret_typs[typ.name]: raise ValueError(f"duplicate type name {repr(typ.name)}") else: ret_typs[typ.name] = typ msgids: set[int] = set() for typ in ret_typs.values(): if isinstance(typ, Message): if typ.msgid in msgids: raise ValueError(f"duplicate msgid {repr(typ.msgid)}") msgids.add(typ.msgid) return ret_versions, list(ret_typs.values())