summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-12-20 01:18:37 -0700
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-12-26 17:48:44 -0700
commitd6bac0f1ca6489e1d532c4e1a5fb3e6b774aafd7 (patch)
tree6334ec564f54fb1a416a7983354e1e49299f8081 /lib9p/idl.gen
parent4eba23686e58008faa7798182ef3673b799b7e64 (diff)
lib9p: Pull the IDL parser into a separate file
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen558
1 files changed, 47 insertions, 511 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index ae7f1a5..47d1102 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -5,452 +5,17 @@
# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
-import enum
import os.path
-import re
-from abc import ABC, abstractmethod
-from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast
+import sys
+
+sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))
+
+import idl
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".
-# Types ########################################################################
-
-
-class Primitive(enum.Enum):
- u8 = 1
- u16 = 2
- u32 = 4
- u64 = 8
-
- @property
- def in_versions(self) -> set[str]:
- return set()
-
- @property
- def name(self) -> str:
- return str(self.value)
-
- @property
- def static_size(self) -> int:
- return self.value
-
-
-class Number:
- name: str
- in_versions: set[str]
-
- prim: Primitive
-
- def __init__(self) -> None:
- self.in_versions = set()
-
- @property
- def static_size(self) -> int:
- return self.prim.static_size
-
-
-class BitfieldVal:
- name: str
- in_versions: set[str]
-
- val: str
-
- def __init__(self) -> None:
- self.in_versions = set()
-
-
-class Bitfield:
- name: str
- in_versions: set[str]
-
- prim: Primitive
-
- bits: list[str] # bitnames
- names: dict[str, BitfieldVal] # bits *and* aliases
-
- def __init__(self) -> None:
- self.in_versions = set()
- self.names = {}
-
- @property
- def static_size(self) -> int:
- return self.prim.static_size
-
- def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
- """Return whether the given bit is valid in the given protocol
- version.
-
- """
- bitname = self.bits[bit] if isinstance(bit, int) else bit
- assert bitname in self.bits
- if not bitname:
- return False
- if bitname.startswith("_"):
- return False
- if ver and (ver not in self.names[bitname].in_versions):
- return False
- return True
-
-
-class ExprLit:
- val: int
-
- def __init__(self, val: int) -> None:
- self.val = val
-
-
-class ExprSym:
- name: str
-
- def __init__(self, name: str) -> None:
- self.name = name
-
-
-class ExprOp:
- op: Literal["-", "+"]
-
- def __init__(self, op: Literal["-", "+"]) -> None:
- self.op = op
-
-
-class Expr:
- tokens: list[ExprLit | ExprSym | ExprOp]
-
- def __init__(self) -> None:
- self.tokens = []
-
- def __bool__(self) -> bool:
- return len(self.tokens) > 0
-
-
-class StructMember:
- # from left-to-right when parsing
- cnt: str | None = None
- name: str
- typ: "Type"
- max: Expr
- val: Expr
-
- in_versions: set[str]
-
- @property
- def static_size(self) -> int | None:
- if self.cnt:
- return None
- return self.typ.static_size
-
-
-class Struct:
- name: str
- in_versions: set[str]
-
- members: list[StructMember]
-
- def __init__(self) -> None:
- self.in_versions = set()
-
- @property
- def static_size(self) -> int | None:
- size = 0
- for member in self.members:
- msize = member.static_size
- if msize is None:
- return None
- size += msize
- return size
-
-
-class Message(Struct):
- @property
- def msgid(self) -> int:
- assert len(self.members) >= 3
- assert self.members[1].name == "typ"
- assert self.members[1].static_size == 1
- assert self.members[1].val
- assert len(self.members[1].val.tokens) == 1
- assert isinstance(self.members[1].val.tokens[0], ExprLit)
- return self.members[1].val.tokens[0].val
-
-
-Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message
-# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
-T = TypeVar("T", Number, Bitfield, Struct, Message)
-
-# Parse *.9p ###################################################################
-
-re_priname = "(?:1|2|4|8)" # primitive names
-re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)" # "symbol" names; most *.9p-defined names
-re_impname = r"(?:\*|" + re_symname + ")" # names we can import
-re_msgname = r"(?:[TR][a-zA-Z_0-9]*)" # names a message can be
-
-re_memtype = f"(?:{re_symname}|{re_priname})" # typenames that a struct member can be
-
-re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)"
-
-re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})"
-re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"
-
-re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"
-
-
-def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
- spec = spec.strip()
-
- bit: int | None
- val: BitfieldVal
- if m := re.fullmatch(re_bitspec_bit, spec):
- bit = int(m.group("bit"))
- name = m.group("name")
-
- val = BitfieldVal()
- val.name = name
- val.val = f"1<<{bit}"
- val.in_versions.add(ver)
-
- if bit < 0 or bit >= len(bf.bits):
- raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds")
- if bf.bits[bit]:
- raise ValueError(f"{bf.name}: bit {bit} already assigned")
- bf.bits[bit] = val.name
- elif m := re.fullmatch(re_bitspec_alias, spec):
- name = m.group("name")
- valstr = m.group("val")
-
- val = BitfieldVal()
- val.name = name
- val.val = valstr
- val.in_versions.add(ver)
- else:
- raise SyntaxError(f"invalid bitfield spec {repr(spec)}")
-
- if val.name in bf.names:
- raise ValueError(f"{bf.name}: name {val.name} already assigned")
- bf.names[val.name] = val
-
-
-def parse_expr(expr: str) -> Expr:
- assert re.fullmatch(re_expr, expr)
- ret = Expr()
- for tok in re.split("([-+])", expr):
- if tok == "-" or tok == "+":
- # I, for the life of me, do not understand why I need this
- # cast() to keep mypy happy.
- ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))]
- elif re.fullmatch("[0-9]+", tok):
- ret.tokens += [ExprLit(int(tok))]
- else:
- ret.tokens += [ExprSym(tok)]
- return ret
-
-
-def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None:
- for spec in specs.split():
- m = re.fullmatch(re_memberspec, spec)
- if not m:
- raise SyntaxError(f"invalid member spec {repr(spec)}")
-
- member = StructMember()
- member.in_versions = {ver}
-
- member.name = m.group("name")
- if any(x.name == member.name for x in struct.members):
- raise ValueError(f"duplicate member name {repr(member.name)}")
-
- if m.group("typ") not in env:
- raise NameError(f"Unknown type {repr(m.group(2))}")
- member.typ = env[m.group("typ")]
-
- if cnt := m.group("cnt"):
- if len(struct.members) == 0 or struct.members[-1].name != cnt:
- raise ValueError(f"list count must be previous item: {repr(cnt)}")
- if not isinstance(struct.members[-1].typ, Primitive):
- raise ValueError(f"list count must be an integer type: {repr(cnt)}")
- member.cnt = cnt
-
- if maxstr := m.group("max"):
- if (not isinstance(member.typ, Primitive)) or member.cnt:
- raise ValueError("',max=' may only be specified on a non-repeated atom")
- member.max = parse_expr(maxstr)
- else:
- member.max = Expr()
-
- if valstr := m.group("val"):
- if (not isinstance(member.typ, Primitive)) or member.cnt:
- raise ValueError("',val=' may only be specified on a non-repeated atom")
- member.val = parse_expr(valstr)
- else:
- member.val = Expr()
-
- struct.members += [member]
-
-
-def re_string(grpname: str) -> str:
- return f'"(?P<{grpname}>[^"]*)"'
-
-
-re_line_version = f"version\\s+{re_string('version')}"
-re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)"
-re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
-re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
-re_line_bitfield_ = (
- f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}"
-)
-re_line_struct = (
- f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
-)
-re_line_msg = (
- f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
-)
-re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg
-
-
-def parse_file(
- filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
-) -> tuple[str, list[Type]]:
- version: str | None = None
- env: dict[str, Type] = {
- "1": Primitive.u8,
- "2": Primitive.u16,
- "4": Primitive.u32,
- "8": Primitive.u64,
- }
-
- def get_type(name: str, tc: type[T]) -> T:
- nonlocal env
- if name not in env:
- raise NameError(f"Unknown type {repr(name)}")
- ret = env[name]
- if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
- raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
- return ret
-
- with open(filename, "r") as fh:
- prev: Type | None = None
- for line in fh:
- line = line.split("#", 1)[0].rstrip()
- if not line:
- continue
- if m := re.fullmatch(re_line_version, line):
- if version:
- raise SyntaxError("must have exactly 1 version line")
- version = m.group("version")
- continue
- if not version:
- raise SyntaxError("must have exactly 1 version line")
-
- if m := re.fullmatch(re_line_import, line):
- other_version, other_typs = get_include(m.group("file"))
- for symname in m.group("syms").split(sep=","):
- symname = symname.strip()
- for typ in other_typs:
- if typ.name == symname or symname == "*":
- match typ:
- case Primitive():
- pass
- case Number():
- typ.in_versions.add(version)
- case Bitfield():
- typ.in_versions.add(version)
- for val in typ.names.values():
- if other_version in val.in_versions:
- val.in_versions.add(version)
- case Struct(): # and Message()
- typ.in_versions.add(version)
- for member in typ.members:
- if other_version in member.in_versions:
- member.in_versions.add(version)
- env[typ.name] = typ
- elif m := re.fullmatch(re_line_num, line):
- num = Number()
- num.name = m.group("name")
- num.in_versions.add(version)
-
- prim = env[m.group("prim")]
- assert isinstance(prim, Primitive)
- num.prim = prim
-
- env[num.name] = num
- prev = num
- elif m := re.fullmatch(re_line_bitfield, line):
- bf = Bitfield()
- bf.name = m.group("name")
- bf.in_versions.add(version)
-
- prim = env[m.group("prim")]
- assert isinstance(prim, Primitive)
- bf.prim = prim
-
- bf.bits = (prim.static_size * 8) * [""]
-
- env[bf.name] = bf
- prev = bf
- elif m := re.fullmatch(re_line_bitfield_, line):
- bf = get_type(m.group("name"), Bitfield)
- parse_bitspec(version, bf, m.group("member"))
-
- prev = bf
- elif m := re.fullmatch(re_line_struct, line):
- match m.group("op"):
- case "=":
- struct = Struct()
- struct.name = m.group("name")
- struct.in_versions.add(version)
- struct.members = []
- parse_members(version, env, struct, m.group("members"))
-
- env[struct.name] = struct
- prev = struct
- case "+=":
- struct = get_type(m.group("name"), Struct)
- parse_members(version, env, struct, m.group("members"))
-
- prev = struct
- elif m := re.fullmatch(re_line_msg, line):
- match m.group("op"):
- case "=":
- msg = Message()
- msg.name = m.group("name")
- msg.in_versions.add(version)
- msg.members = []
- parse_members(version, env, msg, m.group("members"))
-
- env[msg.name] = msg
- prev = msg
- case "+=":
- msg = get_type(m.group("name"), Message)
- parse_members(version, env, msg, m.group("members"))
-
- prev = msg
- elif m := re.fullmatch(re_line_cont, line):
- match prev:
- case Bitfield():
- parse_bitspec(version, prev, m.group("specs"))
- case Struct(): # and Message()
- parse_members(version, env, prev, m.group("specs"))
- case _:
- raise SyntaxError(
- "continuation line must come after a bitfield, struct, or msg line"
- )
- else:
- raise SyntaxError(f"invalid line {repr(line)}")
- if not version:
- raise SyntaxError("must have exactly 1 version line")
-
- typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]
-
- for typ in [typ for typ in typs if isinstance(typ, Struct)]:
- valid_syms = ["end", *["&" + m.name for m in typ.members]]
- for member in typ.members:
- for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, ExprSym) and tok.name not in valid_syms:
- raise ValueError(
- f"{typ.name}.{member.name}: invalid sym: {tok.name}"
- )
-
- return version, typs
-
# Generate C ###################################################################
@@ -473,17 +38,17 @@ def c_ver_cond(versions: set[str]) -> str:
return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-def c_typename(typ: Type) -> str:
+def c_typename(typ: idl.Type) -> str:
match typ:
- case Primitive():
+ case idl.Primitive():
return f"uint{typ.value*8}_t"
- case Number():
+ case idl.Number():
return f"{idprefix}{typ.name}_t"
- case Bitfield():
+ case idl.Bitfield():
return f"{idprefix}{typ.name}_t"
- case Message():
+ case idl.Message():
return f"struct {idprefix}msg_{typ.name}"
- case Struct():
+ case idl.Struct():
return f"struct {idprefix}{typ.name}"
case _:
raise ValueError(f"not a type: {typ.__class__.__name__}")
@@ -531,7 +96,7 @@ def ifdef_pop(n: int) -> str:
return ret
-def gen_h(versions: set[str], typs: list[Type]) -> str:
+def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -576,13 +141,13 @@ enum {idprefix}version {{
ret += """
/* non-message types **********************************************************/
"""
- for typ in [typ for typ in typs if not isinstance(typ, Message)]:
+ for typ in [typ for typ in typs if not isinstance(typ, idl.Message)]:
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
match typ:
- case Number():
+ case idl.Number():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
names = [
*reversed(
@@ -620,7 +185,7 @@ enum {idprefix}version {{
sp3 = " " * (2 + namewidth - len(name))
ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n"
ret += ifdef_pop(1)
- case Struct():
+ case idl.Struct():
typewidth = max(len(c_typename(m.typ)) for m in typ.members)
ret += c_typename(typ) + " {\n"
@@ -641,14 +206,14 @@ enum {idprefix}version {{
"""
ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
- namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message))
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message))
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
ret += ifdef_pop(0)
ret += "};\n"
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
ret += c_typename(msg) + " {"
@@ -671,22 +236,22 @@ enum {idprefix}version {{
return ret
-def c_expr(expr: Expr) -> str:
+def c_expr(expr: idl.Expr) -> str:
ret: list[str] = []
for tok in expr.tokens:
match tok:
- case ExprOp():
+ case idl.ExprOp():
ret += [tok.op]
- case ExprLit():
+ case idl.ExprLit():
ret += [str(tok.val)]
- case ExprSym(name="end"):
+ case idl.ExprSym(name="end"):
ret += ["ctx->net_offset"]
- case ExprSym():
+ case idl.ExprSym():
ret += [f"_{tok.name[1:]}_offset"]
return " ".join(ret)
-def gen_c(versions: set[str], typs: list[Type]) -> str:
+def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -771,12 +336,12 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
#define validate_8(ctx) _validate_size_net(ctx, 8)
"""
for typ in typs:
- inline = "LM_FLATTEN" if isinstance(typ, Message) else "LM_ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- if isinstance(typ, Bitfield):
+ if isinstance(typ, idl.Bitfield):
ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
verwidth = max(len(ver) for ver in versions)
for ver in sorted(versions):
@@ -820,9 +385,9 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
continue
match typ:
- case Number():
+ case idl.Number():
ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
ret += (
@@ -832,7 +397,7 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
ret += f"\tif (val & ~mask)\n"
ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
- case Struct(): # and Message()
+ case idl.Struct(): # and idl.Message()
if len(typ.members) == 0:
ret += "\treturn false;\n"
ret += "}\n"
@@ -849,7 +414,7 @@ LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
mark_offset: set[str] = set()
for member in typ.members:
for tok in [*member.max.tokens, *member.val.tokens]:
- if isinstance(tok, ExprSym) and tok.name.startswith("&"):
+ if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
if tok.name[1:] not in mark_offset:
ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
mark_offset.add(tok.name[1:])
@@ -924,17 +489,17 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
}
"""
for typ in typs:
- inline = "LM_FLATTEN" if isinstance(typ, Message) else "LM_ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
match typ:
- case Number():
+ case idl.Number():
ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
- case Struct():
+ case idl.Struct():
ret += "\tmemset(out, 0, sizeof(*out));\n"
for member in typ.members:
@@ -1017,18 +582,18 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
}
"""
for typ in typs:
- inline = "LM_FLATTEN" if isinstance(typ, Message) else "LM_ALWAYS_INLINE"
- argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
match typ:
- case Number():
+ case idl.Number():
ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
- case Bitfield():
+ case idl.Bitfield():
ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
- case Struct():
+ case idl.Struct():
if len(typ.members) == 0:
ret += "\treturn false;\n"
ret += "}\n"
@@ -1042,7 +607,7 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
ret += f"\tuint32_t _{member.name}_offset;\n"
mark_offset.add(member.name)
for tok in member.val.tokens:
- if isinstance(tok, ExprSym) and tok.name.startswith("&"):
+ if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
if tok.name[1:] not in mark_offset:
ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
mark_offset.add(tok.name[1:])
@@ -1109,8 +674,8 @@ LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val)
struct _table_version _{idprefix}versions[{c_ver_enum('NUM')}] = {{
"""
- id2typ: dict[int, Message] = {}
- for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ id2typ: dict[int, idl.Message] = {}
+ for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
id2typ[msg.msgid] = msg
for ver in ["unknown", *sorted(versions)]:
@@ -1119,7 +684,7 @@ struct _table_version _{idprefix}versions[{c_ver_enum('NUM')}] = {{
ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n"
for n in range(0, 0x100):
- xmsg: Message | None = id2typ.get(n, None)
+ xmsg: idl.Message | None = id2typ.get(n, None)
if xmsg:
if ver == "unknown": # SPECIAL
if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
@@ -1154,41 +719,12 @@ LM_FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_s
################################################################################
-class Parser:
- cache: dict[str, tuple[str, list[Type]]] = {}
-
- def parse_file(self, filename: str) -> tuple[str, list[Type]]:
- filename = os.path.normpath(filename)
- if filename not in self.cache:
-
- def get_include(other_filename: str) -> tuple[str, list[Type]]:
- return self.parse_file(os.path.join(filename, "..", other_filename))
-
- self.cache[filename] = parse_file(filename, get_include)
- return self.cache[filename]
-
- def all(self) -> tuple[set[str], list[Type]]:
- ret_versions: set[str] = set()
- ret_typs: dict[str, Type] = {}
- for version, typs in self.cache.values():
- if version in ret_versions:
- raise ValueError(f"duplicate protocol version {repr(version)}")
- ret_versions.add(version)
- for typ in typs:
- if typ.name in ret_typs:
- if typ != ret_typs[typ.name]:
- raise ValueError(f"duplicate type name {repr(typ.name)}")
- else:
- ret_typs[typ.name] = typ
- return ret_versions, list(ret_typs.values())
-
-
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
raise ValueError("requires at least 1 .9p filename")
- parser = Parser()
+ parser = idl.Parser()
for txtname in sys.argv[1:]:
parser.parse_file(txtname)
versions, typs = parser.all()