From ee356d885a984d5e79da0da20ce608787f1426f3 Mon Sep 17 00:00:00 2001
From: "Luke T. Shumaker" <lukeshu@lukeshu.com>
Date: Sun, 23 Mar 2025 02:05:13 -0600
Subject: lib9p: protogen: pull idlutil.py out of __init__.py

---
 lib9p/protogen/idlutil.py | 112 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 lib9p/protogen/idlutil.py

(limited to 'lib9p/protogen/idlutil.py')

diff --git a/lib9p/protogen/idlutil.py b/lib9p/protogen/idlutil.py
new file mode 100644
index 0000000..dc4d012
--- /dev/null
+++ b/lib9p/protogen/idlutil.py
@@ -0,0 +1,112 @@
+# lib9p/protogen/idlutil.py - Utilities for working with the 9P idl package
+#
+# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import enum
+import graphlib
+import typing
+
+import idl
+
+# pylint: disable=unused-variable
+__all__ = [
+    "topo_sorted",
+    "Path",
+    "WalkCmd",
+    "WalkHandler",
+    "walk",
+]
+
+# topo_sorted() ################################################################
+
+
+def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
+    ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
+    for typ in typs:
+        match typ:
+            case idl.Number():
+                ts.add(typ)
+            case idl.Bitfield():
+                ts.add(typ)
+            case idl.Struct():  # and idl.Message():
+                deps = [
+                    member.typ
+                    for member in typ.members
+                    if not isinstance(member.typ, idl.Primitive)
+                ]
+                ts.add(typ, *deps)
+    return ts.static_order()
+
+
+# walk() #######################################################################
+
+
+class Path:
+    root: idl.Type
+    elems: list[idl.StructMember]
+
+    def __init__(
+        self, root: idl.Type, elems: list[idl.StructMember] | None = None
+    ) -> None:
+        self.root = root
+        self.elems = elems if elems is not None else []
+
+    def add(self, elem: idl.StructMember) -> "Path":
+        return Path(self.root, self.elems + [elem])
+
+    def parent(self) -> "Path":
+        return Path(self.root, self.elems[:-1])
+
+    def c_str(self, base: str, loopdepth: int = 0) -> str:
+        ret = base
+        for i, elem in enumerate(self.elems):
+            if i > 0:
+                ret += "."
+            ret += elem.membname
+            if elem.cnt:
+                ret += f"[{chr(ord('i')+loopdepth)}]"
+                loopdepth += 1
+        return ret
+
+    def __str__(self) -> str:
+        return self.c_str(self.root.typname + "->")
+
+
+class WalkCmd(enum.Enum):
+    KEEP_GOING = 1
+    DONT_RECURSE = 2
+    ABORT = 3
+
+
+type WalkHandler = typing.Callable[
+    [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
+]
+
+
+def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
+    typ = path.elems[-1].typ if path.elems else path.root
+
+    ret, atexit = handle(path)
+
+    if isinstance(typ, idl.Struct):
+        match ret:
+            case WalkCmd.KEEP_GOING:
+                for member in typ.members:
+                    if _walk(path.add(member), handle) == WalkCmd.ABORT:
+                        ret = WalkCmd.ABORT
+                        break
+            case WalkCmd.DONT_RECURSE:
+                ret = WalkCmd.KEEP_GOING
+            case WalkCmd.ABORT:
+                ret = WalkCmd.ABORT
+            case _:
+                assert False, f"invalid cmd: {ret}"
+
+    if atexit:
+        atexit()
+    return ret
+
+
+def walk(typ: idl.Type, handle: WalkHandler) -> None:
+    _walk(Path(typ), handle)
-- 
cgit v1.2.3-2-g168b