1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
# lib9p/core_gen/idlutil.py - Utilities for working with the 9P idl package
#
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
import enum
import typing
import idl
# pylint: disable=unused-variable
# Names exported by `from <this module> import *`; the `_walk` helper is
# deliberately left out.
__all__ = [
    "topo_sorted",
    "Path",
    "WalkCmd",
    "WalkHandler",
    "walk",
]
# topo_sorted() ################################################################
def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
    """Return the entries of `typs` ordered so that structs come after
    the structs they contain.

    Numbers and bitfields are emitted first, in their input order.
    Structs follow, ordered by nesting depth: a struct with no struct
    members has depth 0, otherwise its depth is one more than the deepest
    struct among its members.  Structs of equal depth keep their input
    order.
    """
    # Memoized nesting depth, keyed by struct type name.
    depth_of: dict[str, int] = {}

    def depth(struct: idl.Struct) -> int:
        name = struct.typname
        if name in depth_of:
            return depth_of[name]
        nested = [
            depth(member.typ)
            for member in struct.members
            if isinstance(member.typ, idl.Struct)
        ]
        depth_of[name] = (1 + max(nested)) if nested else 0
        return depth_of[name]

    ordered: list[idl.UserType] = []
    for typ in typs:
        if isinstance(typ, (idl.Number, idl.Bitfield)):
            ordered.append(typ)
        elif isinstance(typ, idl.Struct):  # and idl.Message():
            depth(typ)

    # sorted() is stable, so structs sharing a depth stay in input order.
    structs = [typ for typ in typs if isinstance(typ, idl.Struct)]
    ordered.extend(sorted(structs, key=lambda typ: depth_of[typ.typname]))

    # Trips if `typs` held something that is none of Number, Bitfield,
    # or Struct, since such an entry is never emitted above.
    assert len(ordered) == len(typs)
    return ordered
# walk() #######################################################################
class Path:
    """A root type plus the chain of struct members leading from it to
    some nested member.

    An empty chain refers to the root itself.
    """

    root: idl.Type
    elems: list[idl.StructMember]

    def __init__(
        self, root: idl.Type, elems: list[idl.StructMember] | None = None
    ) -> None:
        self.root = root
        # A caller-supplied list is stored as-is (not copied).
        self.elems = [] if elems is None else elems

    def add(self, elem: idl.StructMember) -> "Path":
        """Return a new Path that extends this one by `elem`."""
        return Path(self.root, [*self.elems, elem])

    def parent(self) -> "Path":
        """Return a new Path with the last element dropped."""
        return Path(self.root, self.elems[:-1])

    def c_str(self, base: str, loopdepth: int = 0) -> str:
        """Render the path as `base` followed by the "."-joined member
        names.

        Each member with a truthy `cnt` gets an index suffix: "[i]" for
        the first such member when `loopdepth` is 0, then "[j]", and so
        on; a nonzero `loopdepth` starts that many letters later.
        """
        segments: list[str] = []
        for elem in self.elems:
            segment = elem.membname
            if elem.cnt:
                segment += f"[{chr(ord('i')+loopdepth)}]"
                loopdepth += 1
            segments.append(segment)
        return base + ".".join(segments)

    def __str__(self) -> str:
        prefix = self.root.typname + "->"
        return self.c_str(prefix)
class WalkCmd(enum.Enum):
    """Instruction a walk handler returns to steer the traversal."""

    # auto() numbers the members 1, 2, 3 in definition order.

    # Descend into the members of the current struct.
    KEEP_GOING = enum.auto()
    # Skip the members of the current struct, but keep walking elsewhere.
    DONT_RECURSE = enum.auto()
    # Stop the walk.
    ABORT = enum.auto()
# Callback that _walk() invokes once for each Path it visits.  It returns
# the WalkCmd that steers the traversal, plus an optional zero-argument
# callable (or None) that _walk() calls once it has finished with that
# path, including any members it descended into.
type WalkHandler = typing.Callable[
    [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
]
def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
    """Visit `path` and, if the handler allows it, the members beneath it.

    `handle` is called on `path` first.  When the type at the end of the
    path is a struct, the returned command decides what happens next:
    KEEP_GOING recurses into each member (stopping at the first one that
    reports ABORT), DONT_RECURSE skips the members and is reported to the
    caller as KEEP_GOING, and ABORT is passed straight through.  For a
    non-struct type the handler's command is returned unchanged.

    The handler's optional exit callback runs just before returning.
    """
    node = path.elems[-1].typ if path.elems else path.root
    ret, on_exit = handle(path)

    if isinstance(node, idl.Struct):
        if ret == WalkCmd.KEEP_GOING:
            # any() short-circuits, so no member after the aborting one
            # is visited.
            if any(
                _walk(path.add(member), handle) == WalkCmd.ABORT
                for member in node.members
            ):
                ret = WalkCmd.ABORT
        elif ret == WalkCmd.DONT_RECURSE:
            # Only this subtree is skipped; the caller keeps walking.
            ret = WalkCmd.KEEP_GOING
        elif ret != WalkCmd.ABORT:
            assert False, f"invalid cmd: {ret}"

    if on_exit:
        on_exit()
    return ret
def walk(typ: idl.Type, handle: WalkHandler) -> None:
    """Walk `typ`, calling `handle` with a Path for `typ` itself and, when
    it is a struct, for the members beneath it as directed by the
    WalkCmd values that `handle` returns.
    """
    root_path = Path(typ)
    _walk(root_path, handle)
|