summaryrefslogtreecommitdiff
path: root/lib9p/core_gen/idlutil.py
blob: e92839a84f0b1df71bd96d06cd6df4668ca8eca8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# lib9p/core_gen/idlutil.py - Utilities for working with the 9P idl package
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import enum
import typing

import idl

# Public API of this module.  Helpers that are not listed here (such as
# `_walk`) are internal to this module.
# pylint: disable=unused-variable
__all__ = [
    "topo_sorted",
    "Path",
    "WalkCmd",
    "WalkHandler",
    "walk",
]

# topo_sorted() ################################################################


def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
    ret: list[idl.UserType] = []
    struct_ord: dict[str, int] = {}

    def get_struct_ord(typ: idl.Struct) -> int:
        nonlocal struct_ord
        if typ.typname not in struct_ord:
            deps = [
                get_struct_ord(member.typ)
                for member in typ.members
                if isinstance(member.typ, idl.Struct)
            ]
            if len(deps) == 0:
                struct_ord[typ.typname] = 0
            else:
                struct_ord[typ.typname] = 1 + max(deps)
        return struct_ord[typ.typname]

    for typ in typs:
        match typ:
            case idl.Number():
                ret.append(typ)
            case idl.Bitfield():
                ret.append(typ)
            case idl.Struct():  # and idl.Message():
                _ = get_struct_ord(typ)
    for _ord in sorted(set(struct_ord.values())):
        for typ in typs:
            if not isinstance(typ, idl.Struct):
                continue
            if struct_ord[typ.typname] != _ord:
                continue
            ret.append(typ)
    assert len(ret) == len(typs)
    return ret


# walk() #######################################################################


class Path:
    """A route from a root type down through a chain of struct members.

    `root` is the type the walk started at; `elems` lists the struct members
    followed from it, outermost first.  An empty `elems` denotes the root
    itself.
    """

    root: idl.Type
    elems: list[idl.StructMember]

    def __init__(
        self, root: idl.Type, elems: list[idl.StructMember] | None = None
    ) -> None:
        self.root = root
        # A caller-supplied list is stored as-is, not copied.
        self.elems = [] if elems is None else elems

    def add(self, elem: idl.StructMember) -> "Path":
        """Return a new Path one member deeper; `self` is left unchanged."""
        return Path(self.root, [*self.elems, elem])

    def parent(self) -> "Path":
        """Return a new Path without the last member (the root maps to itself)."""
        return Path(self.root, self.elems[:-1])

    def c_str(self, base: str, loopdepth: int = 0) -> str:
        """Render the path as C-style member-access text prefixed by `base`.

        Member names are joined with "."; `base` is prepended verbatim, so it
        should carry its own trailing separator (e.g. "name->").  Each member
        with a truthy `cnt` gets a single-letter subscript: "i" shifted by
        `loopdepth` for the first such member, then the next letter for each
        further one.
        """
        depth = loopdepth
        pieces: list[str] = []
        for elem in self.elems:
            piece = elem.membname
            if elem.cnt:
                piece += f"[{chr(ord('i') + depth)}]"
                depth += 1
            pieces.append(piece)
        return base + ".".join(pieces)

    def __str__(self) -> str:
        return self.c_str(self.root.typname + "->")


class WalkCmd(enum.Enum):
    """What a WalkHandler asks the walker to do after visiting a node."""

    # Descend into a struct's members (value 1).
    KEEP_GOING = enum.auto()
    # Skip this struct's members but continue with its siblings (value 2).
    DONT_RECURSE = enum.auto()
    # Stop the whole walk (value 3).
    ABORT = enum.auto()


type WalkHandler = typing.Callable[
    [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
]


def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
    """Visit `path` with `handle`, then (for structs) its members, depth-first.

    The handler is called on `path` before any of its members.  When the node
    is a struct, the handler's command decides what happens next:

    - KEEP_GOING: walk each member in order, stopping at the first member
      whose walk returns ABORT (ABORT is then returned);
    - DONT_RECURSE: skip the members; KEEP_GOING is returned;
    - ABORT: skip the members; ABORT is returned;
    - anything else trips an assertion.

    For non-struct nodes the handler's command is returned unchanged.  In all
    cases the optional callback returned by the handler runs last, after any
    recursion into members has finished.
    """
    node = path.root if not path.elems else path.elems[-1].typ
    cmd, on_exit = handle(path)

    if isinstance(node, idl.Struct):
        if cmd == WalkCmd.KEEP_GOING:
            # any() short-circuits, so no member after an aborting one is walked.
            child_aborted = any(
                _walk(path.add(member), handle) == WalkCmd.ABORT
                for member in node.members
            )
            if child_aborted:
                cmd = WalkCmd.ABORT
        elif cmd == WalkCmd.DONT_RECURSE:
            # Skipping members is not an error; siblings continue normally.
            cmd = WalkCmd.KEEP_GOING
        elif cmd == WalkCmd.ABORT:
            cmd = WalkCmd.ABORT
        else:
            assert False, f"invalid cmd: {cmd}"

    if on_exit:
        on_exit()
    return cmd


def walk(typ: idl.Type, handle: WalkHandler) -> None:
    """Walk `typ` and, recursively, its struct members, calling `handle` on each.

    The handler's WalkCmd results steer the descent.  The final command is
    discarded, so an ABORT simply ends the walk early.
    """
    start = Path(typ)
    _walk(start, handle)