summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-15 14:04:10 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-22 19:18:38 -0600
commitd912a4d79ed9e51e5dfcc24e6445c1de7dbb1a30 (patch)
tree4d9604cdfa7269cf827b87301e9a685a55c6da2e
parent185c3329145959433b8b805de5f114b66b8fcaee (diff)
lib9p: idl.gen: Have a separate type that excludes idl.Primitive
-rwxr-xr-xlib9p/idl.gen14
-rw-r--r--lib9p/idl/__init__.py17
2 files changed, 16 insertions, 15 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index 779b6d5..b75ffd6 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -147,8 +147,8 @@ def ifdef_pop(n: int) -> str:
# topo_sorted() ################################################################
-def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]:
- ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter()
+def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
+ ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
for typ in typs:
match typ:
case idl.Number():
@@ -375,7 +375,7 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
# Generate .h ##################################################################
-def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
+def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -451,7 +451,7 @@ enum {idprefix}version {{
"""
def per_version_comment(
- typ: idl.Type, fn: typing.Callable[[idl.Type, str], str]
+ typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str]
) -> str:
lines: dict[str, str] = {}
for version in sorted(typ.in_versions):
@@ -471,7 +471,7 @@ enum {idprefix}version {{
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
- def sum_size(typ: idl.Type, version: str) -> str:
+ def sum_size(typ: idl.UserType, version: str) -> str:
sz = get_buffer_size(typ, version)
assert (
sz.min_size <= sz.exp_size
@@ -643,7 +643,7 @@ enum {idprefix}version {{
# Generate .c ##################################################################
-def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
+def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -1088,7 +1088,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o
type OffsetExprRecursion = typing.Callable[[Path], WalkCmd]
- def get_offset_expr(typ: idl.Type, recurse: OffsetExprRecursion) -> OffsetExpr:
+ def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
if not isinstance(typ, idl.Struct):
assert typ.static_size
ret = OffsetExpr()
diff --git a/lib9p/idl/__init__.py b/lib9p/idl/__init__.py
index 042a438..04e1791 100644
--- a/lib9p/idl/__init__.py
+++ b/lib9p/idl/__init__.py
@@ -271,6 +271,7 @@ class Message(Struct):
type Type = Primitive | Number | Bitfield | Struct | Message
+type UserType = Number | Bitfield | Struct | Message
T = typing.TypeVar("T", Number, Bitfield, Struct, Message)
# Parse ########################################################################
@@ -417,8 +418,8 @@ re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg
def parse_file(
- filename: str, get_include: typing.Callable[[str], tuple[str, list[Type]]]
-) -> tuple[str, list[Type]]:
+ filename: str, get_include: typing.Callable[[str], tuple[str, list[UserType]]]
+) -> tuple[str, list[UserType]]:
version: str | None = None
env: dict[str, Type] = {
"1": Primitive.u8,
@@ -577,7 +578,7 @@ def parse_file(
if not version:
raise SyntaxError("must have exactly 1 version line")
- typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]
+ typs: list[UserType] = [x for x in env.values() if not isinstance(x, Primitive)]
for typ in [typ for typ in typs if isinstance(typ, Struct)]:
valid_syms = [
@@ -607,21 +608,21 @@ def parse_file(
class Parser:
- cache: dict[str, tuple[str, list[Type]]] = {}
+ cache: dict[str, tuple[str, list[UserType]]] = {}
- def parse_file(self, filename: str) -> tuple[str, list[Type]]:
+ def parse_file(self, filename: str) -> tuple[str, list[UserType]]:
filename = os.path.normpath(filename)
if filename not in self.cache:
- def get_include(other_filename: str) -> tuple[str, list[Type]]:
+ def get_include(other_filename: str) -> tuple[str, list[UserType]]:
return self.parse_file(os.path.join(filename, "..", other_filename))
self.cache[filename] = parse_file(filename, get_include)
return self.cache[filename]
- def all(self) -> tuple[set[str], list[Type]]:
+ def all(self) -> tuple[set[str], list[UserType]]:
ret_versions: set[str] = set()
- ret_typs: dict[str, Type] = {}
+ ret_typs: dict[str, UserType] = {}
for version, typs in self.cache.values():
if version in ret_versions:
raise ValueError(f"duplicate protocol version {repr(version)}")