summaryrefslogtreecommitdiff
path: root/lib9p/core_gen/idlutil.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib9p/core_gen/idlutil.py')
-rw-r--r--lib9p/core_gen/idlutil.py130
1 files changed, 130 insertions, 0 deletions
diff --git a/lib9p/core_gen/idlutil.py b/lib9p/core_gen/idlutil.py
new file mode 100644
index 0000000..e92839a
--- /dev/null
+++ b/lib9p/core_gen/idlutil.py
@@ -0,0 +1,130 @@
+# lib9p/core_gen/idlutil.py - Utilities for working with the 9P idl package
+#
+# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import enum
+import typing
+
+import idl
+
+# pylint: disable=unused-variable
+__all__ = [
+ "topo_sorted",
+ "Path",
+ "WalkCmd",
+ "WalkHandler",
+ "walk",
+]
+
+# topo_sorted() ################################################################
+
+
+def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
+ ret: list[idl.UserType] = []
+ struct_ord: dict[str, int] = {}
+
+ def get_struct_ord(typ: idl.Struct) -> int:
+ nonlocal struct_ord
+ if typ.typname not in struct_ord:
+ deps = [
+ get_struct_ord(member.typ)
+ for member in typ.members
+ if isinstance(member.typ, idl.Struct)
+ ]
+ if len(deps) == 0:
+ struct_ord[typ.typname] = 0
+ else:
+ struct_ord[typ.typname] = 1 + max(deps)
+ return struct_ord[typ.typname]
+
+ for typ in typs:
+ match typ:
+ case idl.Number():
+ ret.append(typ)
+ case idl.Bitfield():
+ ret.append(typ)
+ case idl.Struct(): # and idl.Message():
+ _ = get_struct_ord(typ)
+ for _ord in sorted(set(struct_ord.values())):
+ for typ in typs:
+ if not isinstance(typ, idl.Struct):
+ continue
+ if struct_ord[typ.typname] != _ord:
+ continue
+ ret.append(typ)
+ assert len(ret) == len(typs)
+ return ret
+
+
+# walk() #######################################################################
+
+
+class Path:
+ root: idl.Type
+ elems: list[idl.StructMember]
+
+ def __init__(
+ self, root: idl.Type, elems: list[idl.StructMember] | None = None
+ ) -> None:
+ self.root = root
+ self.elems = elems if elems is not None else []
+
+ def add(self, elem: idl.StructMember) -> "Path":
+ return Path(self.root, self.elems + [elem])
+
+ def parent(self) -> "Path":
+ return Path(self.root, self.elems[:-1])
+
+ def c_str(self, base: str, loopdepth: int = 0) -> str:
+ ret = base
+ for i, elem in enumerate(self.elems):
+ if i > 0:
+ ret += "."
+ ret += elem.membname
+ if elem.cnt:
+ ret += f"[{chr(ord('i')+loopdepth)}]"
+ loopdepth += 1
+ return ret
+
+ def __str__(self) -> str:
+ return self.c_str(self.root.typname + "->")
+
+
+class WalkCmd(enum.Enum):
+ KEEP_GOING = 1
+ DONT_RECURSE = 2
+ ABORT = 3
+
+
+type WalkHandler = typing.Callable[
+ [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
+]
+
+
+def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
+ typ = path.elems[-1].typ if path.elems else path.root
+
+ ret, atexit = handle(path)
+
+ if isinstance(typ, idl.Struct):
+ match ret:
+ case WalkCmd.KEEP_GOING:
+ for member in typ.members:
+ if _walk(path.add(member), handle) == WalkCmd.ABORT:
+ ret = WalkCmd.ABORT
+ break
+ case WalkCmd.DONT_RECURSE:
+ ret = WalkCmd.KEEP_GOING
+ case WalkCmd.ABORT:
+ ret = WalkCmd.ABORT
+ case _:
+ assert False, f"invalid cmd: {ret}"
+
+ if atexit:
+ atexit()
+ return ret
+
+
+def walk(typ: idl.Type, handle: WalkHandler) -> None:
+ _walk(Path(typ), handle)