diff options
Diffstat (limited to 'lib9p/core_gen/idlutil.py')
-rw-r--r-- | lib9p/core_gen/idlutil.py | 130 |
1 files changed, 130 insertions, 0 deletions
diff --git a/lib9p/core_gen/idlutil.py b/lib9p/core_gen/idlutil.py new file mode 100644 index 0000000..e92839a --- /dev/null +++ b/lib9p/core_gen/idlutil.py @@ -0,0 +1,130 @@ +# lib9p/core_gen/idlutil.py - Utilities for working with the 9P idl package +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import enum +import typing + +import idl + +# pylint: disable=unused-variable +__all__ = [ + "topo_sorted", + "Path", + "WalkCmd", + "WalkHandler", + "walk", +] + +# topo_sorted() ################################################################ + + +def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]: + ret: list[idl.UserType] = [] + struct_ord: dict[str, int] = {} + + def get_struct_ord(typ: idl.Struct) -> int: + nonlocal struct_ord + if typ.typname not in struct_ord: + deps = [ + get_struct_ord(member.typ) + for member in typ.members + if isinstance(member.typ, idl.Struct) + ] + if len(deps) == 0: + struct_ord[typ.typname] = 0 + else: + struct_ord[typ.typname] = 1 + max(deps) + return struct_ord[typ.typname] + + for typ in typs: + match typ: + case idl.Number(): + ret.append(typ) + case idl.Bitfield(): + ret.append(typ) + case idl.Struct(): # and idl.Message(): + _ = get_struct_ord(typ) + for _ord in sorted(set(struct_ord.values())): + for typ in typs: + if not isinstance(typ, idl.Struct): + continue + if struct_ord[typ.typname] != _ord: + continue + ret.append(typ) + assert len(ret) == len(typs) + return ret + + +# walk() ####################################################################### + + +class Path: + root: idl.Type + elems: list[idl.StructMember] + + def __init__( + self, root: idl.Type, elems: list[idl.StructMember] | None = None + ) -> None: + self.root = root + self.elems = elems if elems is not None else [] + + def add(self, elem: idl.StructMember) -> "Path": + return Path(self.root, self.elems + [elem]) + + def parent(self) -> "Path": + return Path(self.root, self.elems[:-1]) + + def c_str(self, base: str, loopdepth: int = 0) -> str: + ret = base + for i, elem in enumerate(self.elems): + if i > 0: + ret += "." + ret += elem.membname + if elem.cnt: + ret += f"[{chr(ord('i')+loopdepth)}]" + loopdepth += 1 + return ret + + def __str__(self) -> str: + return self.c_str(self.root.typname + "->") + + +class WalkCmd(enum.Enum): + KEEP_GOING = 1 + DONT_RECURSE = 2 + ABORT = 3 + + +type WalkHandler = typing.Callable[ + [Path], tuple[WalkCmd, typing.Callable[[], None] | None] +] + + +def _walk(path: Path, handle: WalkHandler) -> WalkCmd: + typ = path.elems[-1].typ if path.elems else path.root + + ret, atexit = handle(path) + + if isinstance(typ, idl.Struct): + match ret: + case WalkCmd.KEEP_GOING: + for member in typ.members: + if _walk(path.add(member), handle) == WalkCmd.ABORT: + ret = WalkCmd.ABORT + break + case WalkCmd.DONT_RECURSE: + ret = WalkCmd.KEEP_GOING + case WalkCmd.ABORT: + ret = WalkCmd.ABORT + case _: + assert False, f"invalid cmd: {ret}" + + if atexit: + atexit() + return ret + + +def walk(typ: idl.Type, handle: WalkHandler) -> None: + _walk(Path(typ), handle) |