summaryrefslogtreecommitdiff
path: root/lib9p/protogen/idlutil.py
blob: dc4d012b35b7b52469b16d27c7b0364cb5e87a58 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# lib9p/protogen/idlutil.py - Utilities for working with the 9P idl package
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import enum
import graphlib
import typing

import idl

# pylint: disable=unused-variable
# Public names of this module (what `from idlutil import *` exports).
__all__ = [
    "topo_sorted",
    "Path",
    "WalkCmd",
    "WalkHandler",
    "walk",
]

# topo_sorted() ################################################################


def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
    """Order `typs` so that each struct comes after the types it contains.

    Numbers and bitfields are added without dependencies.  A struct
    (which, per the original note, also covers messages) depends on
    the type of each of its members, except for members whose type is
    an `idl.Primitive`.  Entries of `typs` that are none of these kinds
    are left out of the graph.

    The result is whatever `graphlib.TopologicalSorter.static_order()`
    produces: a lazy iterable with dependencies first.
    """
    sorter: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
    for node in typs:
        if isinstance(node, (idl.Number, idl.Bitfield)):
            # Leaf types: no dependencies to record.
            sorter.add(node)
        elif isinstance(node, idl.Struct):
            # Primitive member types need no ordering, so only the
            # remaining member types become predecessors.
            sorter.add(
                node,
                *(
                    field.typ
                    for field in node.members
                    if not isinstance(field.typ, idl.Primitive)
                ),
            )
    return sorter.static_order()


# walk() #######################################################################


class Path:
    """A location inside an IDL type: a root type plus a chain of members.

    An empty chain refers to the root type itself.  `add()` and
    `parent()` return new `Path` objects rather than modifying this one.
    """

    # The type that the path starts from.
    root: idl.Type
    # The struct members followed from the root, outermost first.
    elems: list[idl.StructMember]

    def __init__(
        self, root: idl.Type, elems: list[idl.StructMember] | None = None
    ) -> None:
        """Create a path at `root`; `elems` defaults to a fresh empty list."""
        self.root = root
        self.elems = [] if elems is None else elems

    def add(self, elem: idl.StructMember) -> "Path":
        """Return a new path that is this path extended by `elem`."""
        extended = self.elems + [elem]
        return Path(self.root, extended)

    def parent(self) -> "Path":
        """Return a new path with the last member dropped.

        For a path with no members the result also has no members.
        """
        shortened = self.elems[:-1]
        return Path(self.root, shortened)

    def c_str(self, base: str, loopdepth: int = 0) -> str:
        """Render the path as a member-access expression appended to `base`.

        Member names are joined with "." and the result is appended
        directly to `base`.  Every member with a truthy `cnt` gets an
        index suffix: "[i]" for the first such member, "[j]" for the
        next, and so on, with the starting letter shifted forward by
        `loopdepth`.
        """
        segments: list[str] = []
        depth = loopdepth
        for member in self.elems:
            segment = member.membname
            if member.cnt:
                # One index letter per counted member: i, j, k, ...
                segment += f"[{chr(ord('i')+depth)}]"
                depth += 1
            segments.append(segment)
        return base + ".".join(segments)

    def __str__(self) -> str:
        """Render the path using "<root typname>->" as the base."""
        prefix = self.root.typname + "->"
        return self.c_str(prefix)


class WalkCmd(enum.Enum):
    """Command that a walk handler returns to steer the traversal."""

    # Descend into the members of the current node when it is a struct.
    KEEP_GOING = enum.auto()  # value 1
    # Do not visit the members of the current node; the walk itself goes on.
    DONT_RECURSE = enum.auto()  # value 2
    # Stop visiting any further nodes.
    ABORT = enum.auto()  # value 3


# Callback type accepted by walk().  The handler is called with the Path
# of the node being visited and returns a pair:
#   - a WalkCmd telling the walker how to proceed, and
#   - an optional zero-argument callable (or None) that the walker
#     invokes once it has finished with that node and its members.
type WalkHandler = typing.Callable[
    [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
]


def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
    typ = path.elems[-1].typ if path.elems else path.root

    ret, atexit = handle(path)

    if isinstance(typ, idl.Struct):
        match ret:
            case WalkCmd.KEEP_GOING:
                for member in typ.members:
                    if _walk(path.add(member), handle) == WalkCmd.ABORT:
                        ret = WalkCmd.ABORT
                        break
            case WalkCmd.DONT_RECURSE:
                ret = WalkCmd.KEEP_GOING
            case WalkCmd.ABORT:
                ret = WalkCmd.ABORT
            case _:
                assert False, f"invalid cmd: {ret}"

    if atexit:
        atexit()
    return ret


def walk(typ: idl.Type, handle: WalkHandler) -> None:
    """Walk `typ` and, for structs, their members, calling `handle` on each.

    The walk starts from a `Path` rooted at `typ` with no members.  The
    final command produced by the traversal is discarded.
    """
    start = Path(typ)
    _walk(start, handle)