# lib9p/idl/__init__.py - A parser for .9p specification files. # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import enum import os.path import re import typing # pylint: disable=unused-variable __all__ = [ # entrypoint "Parser", # types "Type", "Primitive", *["Expr", "ExprTok", "ExprOp", "ExprLit", "ExprSym", "ExprOff", "ExprNum"], "Number", *["Bitfield", "Bit", "BitCat", "BitNum", "BitAlias"], *["Struct", "StructMember"], "Message", ] # The syntax that this parses is described in `./0000-README.md`. # Utilities #################################################################### def get_type(env: dict[str, "Type"], name: str, tc: type["T"]) -> "T": if name not in env: raise NameError(f"Unknown type {name!r}") ret = env[name] if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): raise NameError(f"Type {ret.typname!r} is not a {tc.__name__}") return ret # Types ######################################################################## class Primitive(enum.Enum): u8 = 1 u16 = 2 u32 = 4 u64 = 8 @property def in_versions(self) -> set[str]: return set() @property def typname(self) -> str: return str(self.value) @property def static_size(self) -> int: return self.value def min_size(self, version: str) -> int: return self.value def max_size(self, version: str) -> int: return self.value class ExprOp: op: typing.Literal["-", "+", "<<"] def __init__(self, op: typing.Literal["-", "+", "<<"]) -> None: self.op = op class ExprLit: val: int def __init__(self, val: int) -> None: self.val = val class ExprSym: symname: str def __init__(self, name: str) -> None: self.symname = name class ExprOff: membname: str def __init__(self, name: str) -> None: self.membname = name class ExprNum: numname: str valname: str def __init__(self, numname: str, valname: str) -> None: self.numname = numname self.valname = valname type ExprTok = ExprOp | ExprLit | ExprSym | ExprOff | ExprNum class Expr: tokens: typing.Sequence[ExprTok] const: 
int | None def __init__( self, env: dict[str, "Type"], tokens: typing.Sequence[ExprTok] = () ) -> None: self.tokens = tokens self.const = self._const(env, tokens) def _const( self, env: dict[str, "Type"], toks: typing.Sequence[ExprTok] ) -> int | None: if not toks: return None def read_val() -> int | None: nonlocal toks assert toks neg = False match toks[0]: case ExprOp(op="-"): neg = True toks = toks[1:] assert not isinstance(toks[0], ExprOp) val: int match toks[0]: case ExprLit(): val = toks[0].val case ExprSym(): if m := re.fullmatch(r"^u(8|16|32|64)_max$", toks[0].symname): n = int(m.group(1)) val = (1 << n) - 1 elif m := re.fullmatch(r"^s(8|16|32|64)_max$", toks[0].symname): n = int(m.group(1)) val = (1 << (n - 1)) - 1 else: return None case ExprOff(): return None case ExprNum(): num = get_type(env, toks[0].numname, Number) if toks[0].valname not in num.vals: raise NameError( f"Type {toks[0].numname!r} does not have a value {toks[0].valname!r}" ) _val = num.vals[toks[0].valname].const if _val is None: return None val = _val toks = toks[1:] return -val if neg else val ret = read_val() if ret is None: return None while toks: assert isinstance(toks[0], ExprOp) op = toks[0].op toks = toks[1:] operand = read_val() if operand is None: return None match op: case "+": ret = ret + operand case "-": ret = ret - operand case "<<": ret = ret << operand return ret def __bool__(self) -> bool: return len(self.tokens) > 0 class Number: typname: str in_versions: set[str] prim: Primitive vals: dict[str, Expr] def __init__(self) -> None: self.in_versions = set() self.vals = {} @property def static_size(self) -> int: return self.prim.static_size def min_size(self, version: str) -> int: return self.static_size def max_size(self, version: str) -> int: return self.static_size class BitAlias: bitname: str in_versions: set[str] val: Expr def __init__(self, name: str, val: Expr) -> None: if val.const is None: raise ValueError(f"{name!r} value is not constant") self.bitname = name 
self.in_versions = set() self.val = val class BitNum: numname: str mask: int vals: dict[str, BitAlias] def __init__(self, name: str) -> None: self.numname = name self.mask = 0 self.vals = {} type BitCat = typing.Literal["UNUSED", "USED", "RESERVED"] | BitNum class Bit: bitname: str in_versions: set[str] num: int cat: BitCat def __init__(self, num: int) -> None: self.bitname = "" self.in_versions = set() self.num = num self.cat = "UNUSED" class Bitfield: typname: str in_versions: set[str] prim: Primitive bits: list[Bit] nums: dict[str, BitNum] masks: dict[str, BitAlias] aliases: dict[str, BitAlias] names: set[str] def __init__(self, name: str, prim: Primitive) -> None: self.typname = name self.in_versions = set() self.prim = prim self.bits = [Bit(i) for i in range(prim.static_size * 8)] self.nums = {} self.masks = {} self.aliases = {} self.names = set() @property def static_size(self) -> int: return self.prim.static_size def min_size(self, version: str) -> int: return self.static_size def max_size(self, version: str) -> int: return self.static_size class StructMember: # from left-to-right when parsing cnt: "StructMember| int | None" = None membname: str typ: "Type" max: Expr val: Expr in_versions: set[str] @property def min_cnt(self) -> int: assert self.cnt if isinstance(self.cnt, int): return self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( f"list count must be an integer type: {self.cnt.membname!r}" ) if self.cnt.val: # TODO: allow this? raise ValueError(f"list count may not have ,val=: {self.cnt.membname!r}") return 0 @property def max_cnt(self) -> int: assert self.cnt if isinstance(self.cnt, int): return self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( f"list count must be an integer type: {self.cnt.membname!r}" ) if self.cnt.val: # TODO: allow this? raise ValueError(f"list count may not have ,val=: {self.cnt.membname!r}") if self.cnt.max: # TODO: be more flexible? 
val = self.cnt.max.const if val is None: raise ValueError( f"list count ,max= must be a constant value: {self.cnt.membname!r}" ) return val return (1 << (self.cnt.typ.value * 8)) - 1 @property def static_size(self) -> int | None: if self.cnt: return None return self.typ.static_size def min_size(self, version: str) -> int: cnt = self.min_cnt if self.cnt else 1 return cnt * self.typ.min_size(version) def max_size(self, version: str) -> int: cnt = self.max_cnt if self.cnt else 1 return cnt * self.typ.max_size(version) class Struct: typname: str in_versions: set[str] members: list[StructMember] def __init__(self) -> None: self.in_versions = set() @property def static_size(self) -> int | None: size = 0 for member in self.members: if member.in_versions < self.in_versions: return None msize = member.static_size if msize is None: return None size += msize return size def min_size(self, version: str) -> int: return sum( member.min_size(version) for member in self.members if (version in member.in_versions) ) def max_size(self, version: str) -> int: return sum( member.max_size(version) for member in self.members if (version in member.in_versions) ) class Message(Struct): @property def msgid(self) -> int: assert len(self.members) >= 3 assert self.members[1].membname == "typ" assert self.members[1].static_size == 1 assert self.members[1].val assert len(self.members[1].val.tokens) == 1 assert isinstance(self.members[1].val.tokens[0], ExprLit) return self.members[1].val.tokens[0].val type Type = Primitive | Number | Bitfield | Struct | Message type UserType = Number | Bitfield | Struct | Message T = typing.TypeVar("T", Number, Bitfield, Struct, Message) # Parse ######################################################################## # common elements ###################### re_priname = "(?:1|2|4|8)" # primitive names re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names re_symname_u = "(?:[A-Z_][A-Z_0-9]*)" # upper-case "symbol" names; bit names 
re_symname_l = "(?:[a-z_][a-z_0-9]*)"  # lower-case "symbol" names; bit names
re_impname = r"(?:\*|" + re_symname + ")"  # names we can import
re_msgname = r"(?:[TR][a-zA-Z_0-9]*)"  # names a message can be
re_memtype = f"(?:{re_symname}|{re_priname})"  # typenames that a struct member can be

valid_syms = [
    "end",
    "u8_max",
    "u16_max",
    "u32_max",
    "u64_max",
    "s8_max",
    "s16_max",
    "s32_max",
    "s64_max",
]

_re_expr_op = r"(?:-|\+|<<)"

# Maps each kind of expression operand to its regex.  The keys double
# as the named-group names in `re_expr_tok`, which `parse_expr()`
# dispatches on.
_res_expr_val = {
    "lit_2": r"0b[01]+",
    "lit_8": r"0[0-7]+",
    "lit_10": r"0(?![0-9bxX])|[1-9][0-9]*",
    "lit_16": r"0[xX][0-9a-fA-F]+",
    "sym": "|".join(valid_syms),  # pre-defined symbols
    "off": f"&{re_symname}",  # offset of a field this struct
    "num": f"{re_symname}\\.{re_symname}",  # `num` values
}

# Matches a single token; exactly one named group is set per match.
re_expr_tok = (
    "(?:"
    + "|".join(
        [
            f"(?P<op>{_re_expr_op})",
            *[f"(?P<{k}>{v})" for k, v in _res_expr_val.items()],
        ]
    )
    + ")"
)

_re_expr_val = "(?:" + "|".join(_res_expr_val.values()) + ")"

# Matches a whole expression: optionally-negated operands separated by
# operators, with optional whitespace between tokens.
re_expr = f"(?:\\s*(?:-\\s*)?{_re_expr_val}\\s*(?:{_re_expr_op}\\s*(?:-\\s*)?{_re_expr_val}\\s*)*)"


def parse_expr(env: dict[str, Type], expr: str) -> Expr:
    """Tokenize `expr` (which must fully match `re_expr`) into an `Expr`.

    `env` is passed through to `Expr`, which uses it to resolve
    `NUM.VALUE` references when computing the constant value.
    """
    assert re.fullmatch(re_expr, expr)
    tokens: list[ExprTok] = []
    for m in re.finditer(re_expr_tok, expr):
        if tok := m.group("op"):
            tokens.append(ExprOp(typing.cast(typing.Literal["-", "+", "<<"], tok)))
        elif tok := m.group("lit_2"):
            tokens.append(ExprLit(int(tok[2:], 2)))  # strip "0b"
        elif tok := m.group("lit_8"):
            tokens.append(ExprLit(int(tok[1:], 8)))  # strip leading "0"
        elif tok := m.group("lit_10"):
            tokens.append(ExprLit(int(tok, 10)))
        elif tok := m.group("lit_16"):
            tokens.append(ExprLit(int(tok[2:], 16)))  # strip "0x"/"0X"
        elif tok := m.group("sym"):
            tokens.append(ExprSym(tok))
        elif tok := m.group("off"):
            tokens.append(ExprOff(tok[1:]))  # strip "&"
        elif tok := m.group("num"):
            [numname, valname] = tok.split(".", 1)
            tokens.append(ExprNum(numname, valname))
        else:
            assert False
    return Expr(env, tokens)


# numspec ##############################

re_numspec = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>{re_expr})"


def parse_numspec(env: dict[str, Type], ver: str, n: Number, spec: str) -> None:
    """Parse a `NAME = EXPR` spec and record it in `n.vals`.

    `ver` is accepted for symmetry with the other spec parsers but is
    not used.

    Raises:
        SyntaxError: if `spec` does not match `re_numspec`.
        ValueError: if the name is already assigned in `n`.
    """
    spec = spec.strip()

    if m := re.fullmatch(re_numspec, spec):
        name = m.group("name")
        if name in n.vals:
            raise ValueError(f"{n.typname}: name {name!r} already assigned")
        val = parse_expr(env, m.group("val"))
        # NOTE(review): parse_expr() always returns an Expr, so this
        # branch never fires; `val.const is None` may have been
        # intended -- confirm before tightening.
        if val is None:
            raise ValueError(
                f"{n.typname}: {name!r} value is not constant: {m.group('val')!r}"
            )
        n.vals[name] = val
    else:
        raise SyntaxError(f"invalid num spec {spec!r}")


# bitspec ##############################

re_bitspec_bit = (
    "bit\\s+(?P<bitnum>[0-9]+)\\s*=\\s*(?:"
    + "|".join(
        [
            f"(?P<name_used>{re_symname_u})",
            f"reserved\\((?P<name_reserved>{re_symname_u})\\)",
            f"num\\((?P<name_num>{re_symname_u})\\)",
        ]
    )
    + ")"
)
re_bitspec_mask = f"mask\\s+(?P<name>{re_symname_u})\\s*=\\s*(?P<val>{re_expr})"
re_bitspec_alias = f"alias\\s+(?P<name>{re_symname_u})\\s*=\\s*(?P<val>{re_expr})"
re_bitspec_num = f"num\\((?P<num>{re_symname_u})\\)\\s+(?P<name>{re_symname_u})\\s*=\\s*(?P<val>{re_expr})"


def parse_bitspec(env: dict[str, Type], ver: str, bf: Bitfield, spec: str) -> None:
    """Parse one bitfield spec and apply it to `bf`.

    The spec takes one of four forms (see the `re_bitspec_*` regexes):
      - `bit N = NAME`, `bit N = reserved(NAME)`, or `bit N = num(NAME)`
      - `mask NAME = EXPR`
      - `alias NAME = EXPR`
      - `num(NUMNAME) NAME = EXPR`

    Everything added is tagged with version `ver`.

    Raises:
        SyntaxError: if `spec` matches none of the forms.
        ValueError: for a bad or duplicate name, an out-of-bounds or
            already-assigned bit number, a non-constant value, or a
            nested-num value that does not fit the num's bit mask.
        NameError: if a nested-num value names a num with no bits.
    """
    spec = spec.strip()

    def check_name(name: str, is_num: bool = False) -> None:
        """Reject reserved or already-used names.

        With `is_num`, a name that already names a nested num is
        allowed, since a num spans several `bit` lines.
        """
        if name == "MASK":
            raise ValueError(f"{bf.typname}: bit name may not be {'MASK'!r}: {name!r}")
        if name.endswith("_MASK"):
            raise ValueError(
                f"{bf.typname}: bit name may not end with {'_MASK'!r}: {name!r}"
            )
        if name in bf.names and not (is_num and name in bf.nums):
            raise ValueError(f"{bf.typname}: bit name already assigned: {name!r}")

    if m := re.fullmatch(re_bitspec_bit, spec):
        bitnum = int(m.group("bitnum"))
        if bitnum < 0 or bitnum >= len(bf.bits):
            raise ValueError(f"{bf.typname}: bit num {bitnum} out-of-bounds")
        bit = bf.bits[bitnum]
        if bit.cat != "UNUSED":
            raise ValueError(f"{bf.typname}: bit num {bitnum} already assigned")
        if name := m.group("name_used"):
            bit.bitname = name
            bit.cat = "USED"
            bit.in_versions.add(ver)
        elif name := m.group("name_reserved"):
            bit.bitname = name
            bit.cat = "RESERVED"
            bit.in_versions.add(ver)
        elif name := m.group("name_num"):
            bit.bitname = name
            if name not in bf.nums:
                bf.nums[name] = BitNum(name)
            # Grow the nested num's mask to cover this bit.
            bf.nums[name].mask |= 1 << bit.num
            bit.cat = bf.nums[name]
            bit.in_versions.add(ver)
        if bit.bitname:
            check_name(name, isinstance(bit.cat, BitNum))
            bf.names.add(bit.bitname)
    elif m := re.fullmatch(re_bitspec_mask, spec):
        mask = BitAlias(m.group("name"), parse_expr(env, m.group("val")))
        mask.in_versions.add(ver)
        check_name(mask.bitname)
        bf.masks[mask.bitname] = mask
        bf.names.add(mask.bitname)
    elif m := re.fullmatch(re_bitspec_alias, spec):
        alias = BitAlias(m.group("name"), parse_expr(env, m.group("val")))
        alias.in_versions.add(ver)
        check_name(alias.bitname)
        bf.aliases[alias.bitname] = alias
        bf.names.add(alias.bitname)
    elif m := re.fullmatch(re_bitspec_num, spec):
        numname = m.group("num")
        alias = BitAlias(m.group("name"), parse_expr(env, m.group("val")))
        alias.in_versions.add(ver)
        check_name(alias.bitname)
        if numname not in bf.nums:
            raise NameError(
                f"{bf.typname}: nested num not allocated any bits: {numname!r}"
            )
        assert alias.val.const is not None
        # The value may only use bits that belong to the nested num.
        if alias.val.const & ~bf.nums[numname].mask:
            raise ValueError(
                f"{bf.typname}: {alias.bitname!r} does not fit within bitmask: val={alias.val.const:b} mask={bf.nums[numname].mask}"
            )
        bf.nums[numname].vals[alias.bitname] = alias
        bf.names.add(alias.bitname)
    else:
        raise SyntaxError(f"invalid bitfield spec {spec!r}")


# struct members #######################

re_memberspec = f"(?:(?P<cnt>{re_symname}|[1-9][0-9]*)\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"
def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None: for spec in specs.split(): m = re.fullmatch(re_memberspec, spec) if not m: raise SyntaxError(f"invalid member spec {spec!r}") member = StructMember() member.in_versions = {ver} member.membname = m.group("name") if any(x.membname == member.membname for x in struct.members): raise ValueError(f"duplicate member name {member.membname!r}") if m.group("typ") not in env: raise NameError(f"Unknown type {m.group('typ')!r}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): if cnt.isnumeric(): member.cnt = int(cnt) else: if len(struct.members) == 0 or struct.members[-1].membname != cnt: raise ValueError(f"list count must be previous item: {cnt!r}") member.cnt = struct.members[-1] _ = member.max_cnt # force validation if maxstr := m.group("max"): if ( not isinstance(member.typ, Primitive) and not isinstance(member.typ, Number) ) or member.cnt: raise ValueError( "',max=' may only be specified on a non-repeated numeric type" ) member.max = parse_expr(env, maxstr) else: member.max = Expr(env) if valstr := m.group("val"): if ( not isinstance(member.typ, Primitive) and not isinstance(member.typ, Number) ) or member.cnt: raise ValueError( "',val=' may only be specified on a non-repeated numeric type" ) member.val = parse_expr(env, valstr) else: member.val = Expr(env) struct.members += [member] # main parser ########################## def re_string(grpname: str) -> str: return f'"(?P<{grpname}>[^"]*)"' re_line_version = f"version\\s+{re_string('version')}" re_line_import = f"from\\s+(?P\\S+)\\s+import\\s+(?P{re_impname}(?:\\s*,\\s*{re_impname})*)" re_line_num = f"num\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield = f"bitfield\\s+(?P{re_symname})\\s*=\\s*(?P{re_priname})" re_line_bitfield_ = ( f"bitfield\\s+(?P{re_symname})\\s*\\+=\\s*{re_string('member')}" ) re_line_struct = ( f"struct\\s+(?P{re_symname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_msg = ( 
f"msg\\s+(?P{re_msgname})\\s*(?P\\+?=)\\s*{re_string('members')}" ) re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg def parse_file( filename: str, get_include: typing.Callable[[str], tuple[str, list[UserType]]] ) -> tuple[str, list[UserType]]: version: str | None = None env: dict[str, Type] = { "1": Primitive.u8, "2": Primitive.u16, "4": Primitive.u32, "8": Primitive.u64, } with open(filename, "r", encoding="utf-8") as fh: prev: Type | None = None for lineno, line in enumerate(fh): try: line = line.split("#", 1)[0].rstrip() if not line: continue if m := re.fullmatch(re_line_version, line): if version: raise SyntaxError("must have exactly 1 version line") version = m.group("version") continue if not version: raise SyntaxError("must have exactly 1 version line") if m := re.fullmatch(re_line_import, line): other_version, other_typs = get_include(m.group("file")) for symname in m.group("syms").split(sep=","): symname = symname.strip() found = False for typ in other_typs: if symname in (typ.typname, "*"): found = True match typ: case Primitive(): pass case Number(): typ.in_versions.add(version) case Bitfield(): typ.in_versions.add(version) for bf_bit in typ.bits: if other_version in bf_bit.in_versions: bf_bit.in_versions.add(version) for bf_num in typ.nums.values(): for bf_val in bf_num.vals.values(): if other_version in bf_val.in_versions: bf_val.in_versions.add(version) for bf_mask in typ.masks.values(): if other_version in bf_mask.in_versions: bf_mask.in_versions.add(version) for bf_alias in typ.aliases.values(): if other_version in bf_alias.in_versions: bf_alias.in_versions.add(version) case Struct(): # and Message() typ.in_versions.add(version) for member in typ.members: if other_version in member.in_versions: member.in_versions.add(version) if typ.typname in env and env[typ.typname] != typ: raise ValueError( f"duplicate type name {typ.typname!r}" ) env[typ.typname] = typ if symname != "*" and not found: raise ValueError( f"import: 
{m.group('file')}: no symbol {symname!r}" ) elif m := re.fullmatch(re_line_num, line): num = Number() num.typname = m.group("name") num.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) num.prim = prim if num.typname in env: raise ValueError(f"duplicate type name {num.typname!r}") env[num.typname] = num prev = num elif m := re.fullmatch(re_line_bitfield, line): prim = env[m.group("prim")] assert isinstance(prim, Primitive) bf = Bitfield(m.group("name"), prim) bf.in_versions.add(version) if bf.typname in env: raise ValueError(f"duplicate type name {bf.typname!r}") env[bf.typname] = bf prev = bf elif m := re.fullmatch(re_line_bitfield_, line): bf = get_type(env, m.group("name"), Bitfield) parse_bitspec(env, version, bf, m.group("member")) prev = bf elif m := re.fullmatch(re_line_struct, line): match m.group("op"): case "=": struct = Struct() struct.typname = m.group("name") struct.in_versions.add(version) struct.members = [] parse_members(version, env, struct, m.group("members")) if struct.typname in env: raise ValueError( f"duplicate type name {struct.typname!r}" ) env[struct.typname] = struct prev = struct case "+=": struct = get_type(env, m.group("name"), Struct) parse_members(version, env, struct, m.group("members")) prev = struct elif m := re.fullmatch(re_line_msg, line): match m.group("op"): case "=": msg = Message() msg.typname = m.group("name") msg.in_versions.add(version) msg.members = [] parse_members(version, env, msg, m.group("members")) if msg.typname in env: raise ValueError(f"duplicate type name {msg.typname!r}") env[msg.typname] = msg prev = msg case "+=": msg = get_type(env, m.group("name"), Message) parse_members(version, env, msg, m.group("members")) prev = msg elif m := re.fullmatch(re_line_cont, line): match prev: case Bitfield(): parse_bitspec(env, version, prev, m.group("specs")) case Number(): parse_numspec(env, version, prev, m.group("specs")) case Struct(): # and Message() parse_members(version, env, 
prev, m.group("specs")) case _: raise SyntaxError( "continuation line must come after a bitfield, struct, or msg line" ) else: raise SyntaxError("invalid line") except (SyntaxError, NameError, ValueError) as e: e2 = SyntaxError(str(e)) e2.filename = filename e2.lineno = lineno + 1 e2.text = line raise e2 from e if not version: raise SyntaxError("must have exactly 1 version line") typs: list[UserType] = [x for x in env.values() if not isinstance(x, Primitive)] for typ in [typ for typ in typs if isinstance(typ, Struct)]: for member in typ.members: if ( not isinstance(member.typ, Primitive) and member.typ.in_versions < member.in_versions ): raise ValueError( f"{typ.typname}.{member.membname}: type {member.typ.typname} does not exist in {member.in_versions.difference(member.typ.in_versions)}" ) for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, ExprOff) and not any( m.membname == tok.membname for m in typ.members ): raise NameError( f"{typ.typname}.{member.membname}: invalid offset: &{tok.membname}" ) return version, typs # Filesystem ################################################################### class Parser: cache: dict[str, tuple[str, list[UserType]]] = {} def parse_file(self, filename: str) -> tuple[str, list[UserType]]: filename = os.path.normpath(filename) if filename not in self.cache: def get_include(other_filename: str) -> tuple[str, list[UserType]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] def all(self) -> tuple[set[str], list[UserType]]: ret_versions: set[str] = set() ret_typs: dict[str, UserType] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {version!r}") ret_versions.add(version) for typ in typs: if typ.typname in ret_typs: if typ != ret_typs[typ.typname]: raise ValueError(f"duplicate type name {typ.typname!r}") else: ret_typs[typ.typname] = typ 
msgids: set[int] = set() for typ in ret_typs.values(): if isinstance(typ, Message): if typ.msgid in msgids: raise ValueError(f"duplicate msgid {typ.msgid!r}") msgids.add(typ.msgid) return ret_versions, list(ret_typs.values())