#!/usr/bin/env python3
# build-aux/stack.c.gen - Analyze stack sizes for compiled objects
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import functools
import os.path
import re
import sys
import typing

################################################################################
#
# Parse the "VCG" language
#
# https://www.rw.cdl.uni-saarland.de/people/sander/private/html/gsvcg1.html
#
# The formal syntax is found at
# ftp://ftp.cs.uni-sb.de/pub/graphics/vcg/vcg.tgz `doc/grammar.txt`.


class VCGElem:
    """One `node:` or `edge:` element parsed from a VCG file by `parse_vcg`."""

    # Element type: "node" or "edge".
    typ: str
    # 1-based line number of the element within its file.
    lineno: int
    # Attribute key/value pairs; quoted values are stored unquoted and unescaped.
    attrs: dict[str, str]


def parse_vcg(reader: typing.TextIO) -> typing.Iterator[VCGElem]:
    """Parse a VCG graph description that has one element per line.

    Each line must be one of:
      - a `graph: { title: "..."` opener or a lone `}` closer (skipped), or
      - a complete `node: { key: value ... }` / `edge: { key: value ... }`
        element, which is yielded as a `VCGElem`.

    Values are bare identifiers or double-quoted strings; quoted strings may
    contain the escapes `\\n`, `\\"`, and `\\\\`.

    Raises:
        SyntaxError: for a line that does not fit this grammar, a duplicate
            key within one element, or an unknown escape sequence.  The
            exception's `lineno` and `offset` are 1-based.
    """
    re_beg = re.compile(r"(edge|node):\s*\{\s*")
    _re_tok = r"[a-zA-Z_][a-zA-Z0-9_]*"
    # A backslash always begins a two-character escape, so it must be
    # excluded from the plain-character class; otherwise `\"` would be read
    # as a literal backslash followed by the closing quote.
    _re_str = r'"(?:[^"\\]|\\.)*"'
    re_attr = re.compile(
        "(" + _re_tok + r")\s*:\s*(" + _re_tok + "|" + _re_str + r")\s*"
    )
    re_end = re.compile(r"\}\s*$")
    re_skip = re.compile(r"(graph:\s*\{\s*title\s*:\s*" + _re_str + r"\s*|\})\s*")
    re_esc = re.compile(r"\\.")
    unescapes = {"\\n": "\n", '\\"': '"', "\\\\": "\\"}

    def syntax_error(msg: str, lineno: int, line: str, pos: int) -> SyntaxError:
        e = SyntaxError(msg)
        e.lineno = lineno
        e.offset = pos + 1  # SyntaxError offsets are 1-based
        e.text = line
        return e

    for lineno, line in enumerate(reader, start=1):
        if re_skip.fullmatch(line):
            continue

        m = re_beg.match(line)
        if not m:
            raise syntax_error("does not look like a VCG line", lineno, line, 0)

        elem = VCGElem()
        elem.lineno = lineno
        elem.typ = m.group(1)
        elem.attrs = {}
        pos = m.end()

        while not re_end.match(line, pos=pos):
            m = re_attr.match(line, pos=pos)
            if not m:
                raise syntax_error("unexpected character", lineno, line, pos)
            k, v = m.group(1), m.group(2)
            if k in elem.attrs:
                raise syntax_error(f"duplicate key: {repr(k)}", lineno, line, pos)
            if v.startswith('"'):
                v = v[1:-1]
                # Validate every escape first so an unknown one is reported
                # rather than passed through.
                for esc in re_esc.finditer(v):
                    if esc.group(0) not in unescapes:
                        raise syntax_error(
                            f"invalid escape code {repr(esc.group(0))}",
                            lineno,
                            line,
                            pos,
                        )
                v = re_esc.sub(lambda esc: unescapes[esc.group(0)], v)
            elem.attrs[k] = v
            pos = m.end()

        yield elem


################################################################################
# Main analysis

UsageKind: typing.TypeAlias = typing.Literal["static", "dynamic", "dynamic,bounded"]


class Node:
    # from .title (`static` and `__weak` functions are prefixed with
    # the compilation unit .c file.  For static functions that's fine,
    # but we'll have to handle it specially for __weak.).
    funcname: str
    # .label is "{funcname}\n{location}\n{nstatic} bytes (static}\n{ndynamic} dynamic objects"
    location: str
    usage_kind: UsageKind
    nstatic: int
    ndynamic: int

    # edges with .sourcename set to this node, val is if it's
    # OK/expected that the function be missing.
    calls: dict[str, bool]


def synthetic_node(
    name: str, nstatic: int, calls: typing.Collection[str] = set()
) -> Node:
    n = Node()

    n.funcname = name

    n.location = "<synthetic>"
    n.usage_kind = "static"
    n.nstatic = nstatic
    n.ndynamic = 0

    n.calls = dict((c, False) for c in calls)

    return n


def analyze(
    *,
    ci_fnames: list[str],
    extra_nodes: list[Node] = [],
    app_func_filters: dict[str, typing.Callable[[str], int]],
    app_location_xform: typing.Callable[[str], str],
    app_indirect_callees: typing.Callable[[VCGElem], list[str]],
    app_skip_call: typing.Callable[[list[str], str], bool],
    cfg_max_call_depth: int,
) -> None:
    re_node_label = re.compile(
        r"(?P<funcname>[^\n]+)\n"
        + r"(?P<location>[^\n]+:[0-9]+:[0-9]+)\n"
        + r"(?P<nstatic>[0-9]+) bytes \((?P<usage_kind>static|dynamic|dynamic,bounded)\)\n"
        + r"(?P<ndynamic>[0-9]+) dynamic objects"
        + r"(?:\n.*)?",
        flags=re.MULTILINE,
    )

    graph: dict[str, Node] = dict()
    qualified: dict[str, set[str]] = dict()

    def handle_elem(elem: VCGElem) -> None:
        match elem.typ:
            case "node":
                node = Node()
                node.calls = {}
                skip = False
                for k, v in elem.attrs.items():
                    match k:
                        case "title":
                            node.funcname = v
                        case "label":
                            if elem.attrs.get("shape", "") != "ellipse":
                                m = re_node_label.fullmatch(v)
                                if not m:
                                    raise ValueError(
                                        f"unexpected label value {repr(v)}"
                                    )
                                node.location = m.group("location")
                                node.usage_kind = typing.cast(
                                    UsageKind, m.group("usage_kind")
                                )
                                node.nstatic = int(m.group("nstatic"))
                                node.ndynamic = int(m.group("ndynamic"))
                        case "shape":
                            if v != "ellipse":
                                raise ValueError(f"unexpected shape value {repr(v)}")
                            skip = True
                        case _:
                            raise ValueError(f"unknown edge key {repr(k)}")
                if not skip:
                    if node.funcname in graph:
                        raise ValueError(f"duplicate node {repr(node.funcname)}")
                    graph[node.funcname] = node
                    if ":" in node.funcname:
                        _, shortname = node.funcname.rsplit(":", 1)
                        if shortname not in qualified:
                            qualified[shortname] = set()
                        qualified[shortname].add(node.funcname)
            case "edge":
                caller: str | None = None
                callee: str | None = None
                for k, v in elem.attrs.items():
                    match k:
                        case "sourcename":
                            caller = v
                        case "targetname":
                            callee = v
                        case "label":
                            pass
                        case _:
                            raise ValueError(f"unknown edge key {repr(k)}")
                if caller is None or callee is None:
                    raise ValueError(f"incomplete edge: {repr(elem.attrs)}")
                if caller not in graph:
                    raise ValueError(f"unknown caller: {caller}")
                if callee == "__indirect_call":
                    for callee in app_indirect_callees(elem):
                        if callee not in graph[caller].calls:
                            graph[caller].calls[callee] = True
                else:
                    graph[caller].calls[callee] = False
            case _:
                raise ValueError(f"unknown elem type {repr(elem.typ)}")

    for ci_fname in ci_fnames:
        with open(ci_fname, "r") as fh:
            for elem in parse_vcg(fh):
                handle_elem(elem)

    for node in extra_nodes:
        if node.funcname in graph:
            raise ValueError(f"duplicate node {repr(node.funcname)}")
        graph[node.funcname] = node

    missing: set[str] = set()
    dynamic: set[str] = set()

    dbg = False

    def resolve_funcname(funcname: str) -> str | None:
        # Handle `ld --wrap` functions
        if f"__wrap_{funcname}" in graph:
            return f"__wrap_{funcname}"
        if funcname.startswith("__real_") and funcname[len("__real_") :] in graph:
            funcname = funcname[len("__real_") :]

        # Usual case
        if funcname in graph:
            return funcname

        # Handle `__weak` functions
        if funcname in qualified and len(qualified[funcname]) == 1:
            return sorted(qualified[funcname])[0]

        return None

    def nstatic(
        orig_funcname: str, chain: list[str] = [], missing_ok: bool = False
    ) -> int:
        nonlocal dbg
        funcname = resolve_funcname(orig_funcname)
        if not funcname:
            if app_skip_call(chain, orig_funcname):
                return 0
            if not missing_ok:
                missing.add(orig_funcname)
            return 0
        if app_skip_call(chain, funcname):
            return 0

        if len(chain) == cfg_max_call_depth:
            raise ValueError(f"max call depth exceeded: {chain+[funcname]}")

        node = graph[funcname]
        if dbg:
            print(f"//dbg: {funcname}\t{node.nstatic}")
        if node.usage_kind == "dynamic" or node.ndynamic > 0:
            dynamic.add(app_location_xform(funcname))
        return node.nstatic + max(
            [
                0,
                *[
                    nstatic(call, chain + [funcname], missing_ok)
                    for call, missing_ok in node.calls.items()
                ],
            ]
        )

    print("/*")

    for grp_name, grp_filter in app_func_filters.items():
        # Gather the data.
        nmax = 0
        nsum = 0
        rows: dict[str, int] = {}
        for funcname in graph:
            if cnt := grp_filter(funcname):
                n = nstatic(funcname)
                rows[app_location_xform(funcname)] = n
                if n > nmax:
                    nmax = n
                nsum += cnt * n

        # Figure sizes.
        namelen = max([len(k) for k in rows.keys()] + [len(grp_name) + 4])
        numlen = len(str(nsum))
        sep1 = ("=" * namelen) + " " + "=" * numlen
        sep2 = ("-" * namelen) + " " + "-" * numlen

        # Print.
        print("= " + grp_name + " " + sep1[len(grp_name) + 3 :])
        for name, num in sorted(rows.items()):
            print(f"{name.ljust(namelen)} {str(num).rjust(numlen)}")
        print(sep2)
        print(f"{'Total'.ljust(namelen)} {str(nsum).rjust(numlen)}")
        print(f"{'Maximum'.ljust(namelen)} {str(nmax).rjust(numlen)}")
        print(sep1)

    for funcname in sorted(missing):
        print(f"warning: missing: {funcname}")
    for funcname in sorted(dynamic):
        print(f"warning: dynamic-stack-usage: {funcname}")

    print("*/")


################################################################################
# Application-specific code

re_location = re.compile(r"(?P<filename>.+):(?P<row>[0-9]+):(?P<col>[0-9]+)")


@functools.cache
def _source_lines(filename: str) -> tuple[str, ...]:
    """Return the lines of `filename`, read once and cached for reuse."""
    with open(filename, "r") as fh:
        return tuple(fh.readlines())


def read_source(location: str) -> str:
    """Return the source text at a "FILE:ROW:COL" location.

    ROW and COL are 1-based.  The result runs from COL to the end of that
    line, with trailing whitespace stripped.  Files are read once and cached,
    because many locations usually point into the same file.

    Raises:
        ValueError: if `location` is not "FILE:ROW:COL", or ROW/COL is 0,
            or ROW is past the end of the file.
    """
    m = re_location.fullmatch(location)
    if not m:
        raise ValueError(f"unexpected label value {repr(location)}")
    filename = m.group("filename")
    row = int(m.group("row")) - 1
    col = int(m.group("col")) - 1
    lines = _source_lines(filename)
    # Guard against 0 (which would otherwise index from the end) and EOF.
    if row < 0 or col < 0 or row >= len(lines):
        raise ValueError(f"location out of range: {repr(location)}")
    return lines[row][col:].rstrip()


def main(
    *,
    arg_pico_platform: str,
    arg_base_dir: str,
    arg_ci_fnames: list[str],
    arg_c_fnames: list[str],
) -> None:

    re_call_other = re.compile(r"(?P<func>[^(]+)\(.*")

    all_nodes: list[Node] = []
    hooks_is_intrhandler: list[typing.Callable[[str], bool]] = []
    hooks_indirect_callees: list[typing.Callable[[str, str], list[str] | None]] = []
    hooks_skip_call: list[typing.Callable[[list[str], str], bool]] = []

    # The sbc-harness codebase #######################################

    objcalls: dict[str, set[str]] = {}
    re_vtable_start = re.compile(r"_vtable\s*=\s*\{")
    re_vtable_entry = re.compile(r"^\s+\.(?P<meth>\S+)\s*=\s*(?P<impl>\S+),.*")
    for fname in c_fnames:
        with open(fname, "r") as fh:
            in_vtable = False
            for line in fh:
                line = line.rstrip()
                if in_vtable:
                    if m := re_vtable_entry.fullmatch(line):
                        meth = m.group("meth")
                        impl = m.group("impl")
                        if impl == "NULL":
                            continue
                        if m.group("meth") not in objcalls:
                            objcalls[meth] = set()
                        objcalls[meth].add(impl)
                    if "}" in line:
                        in_vtable = False
                elif re_vtable_start.search(line):
                    in_vtable = True

    tmessage_handlers: set[str] | None = None
    if any(fname.endswith("lib9p/srv.c") for fname in c_fnames):
        srv_c = next(fname for fname in c_fnames if fname.endswith("lib9p/srv.c"))
        re_tmessage_handler = re.compile(
            r"^\s*\[LIB9P_TYP_T[^]]+\]\s*=\s*\(tmessage_handler\)\s*(?P<handler>\S+),\s*$"
        )
        tmessage_handlers = set()
        with open(srv_c, "r") as fh:
            for line in fh:
                line = line.rstrip()
                if m := re_tmessage_handler.fullmatch(line):
                    tmessage_handlers.add(m.group("handler"))

    lib9p_msgs: set[str] = set()
    if any(fname.endswith("lib9p/9p.c") for fname in c_fnames):
        generated_c = next(
            fname for fname in c_fnames if fname.endswith("lib9p/9p.generated.c")
        )
        re_lib9p_msg_entry = re.compile(r"^\s*_MSG_(?:[A-Z]+)\((?P<typ>\S+)\),$")
        with open(generated_c, "r") as fh:
            for line in fh:
                line = line.rstrip()
                if m := re_lib9p_msg_entry.fullmatch(line):
                    typ = m.group("typ")
                    lib9p_msgs.add(typ)

    re_call_objcall = re.compile(r"LO_CALL\((?P<obj>[^,]+), (?P<meth>[^,)]+)[,)].*")

    def sbc_indirect_callees(loc: str, line: str) -> list[str] | None:
        if "/3rd-party/" in loc:
            return None
        if m := re_call_objcall.fullmatch(line):
            if m.group("meth") in objcalls:
                return sorted(objcalls[m.group("meth")])
            return [f"__indirect_call:{m.group('obj')}.vtable->{m.group('meth')}"]
        if "trigger->cb(trigger->cb_arg)" in line:
            return [
                "alarmclock_sleep_intrhandler",
                "w5500_tcp_alarm_handler",
                "w5500_udp_alarm_handler",
            ]
        if "/chan.h:" in loc and "front->dequeue(" in line:
            return [
                "_cr_chan_dequeue",
                "_cr_select_dequeue",
            ]
        if tmessage_handlers and "/srv.c:" in loc and "tmessage_handlers[typ](" in line:
            return sorted(tmessage_handlers)
        if lib9p_msgs and "/9p.c:" in loc:
            for meth in ["validate", "unmarshal", "marshal"]:
                if line.startswith(f"tentry.{meth}("):
                    return sorted(f"{meth}_{msg}" for msg in lib9p_msgs)
        return None

    hooks_indirect_callees += [sbc_indirect_callees]

    def sbc_is_thread(name: str) -> int:
        if name.endswith("_cr") and name != "lib9p_srv_read_cr":
            if "9p" in name:
                if "read" in name:
                    return 8
                elif "write" in name:
                    return 16
            return 1
        if name == "main":
            return True
        return False

    def sbc_is_intrhandler(name: str) -> bool:
        return name in [
            "rp2040_hwtimer_intrhandler",
            "_cr_gdb_intrhandler",
            "hostclock_handle_sig_alarm",
            "hostnet_handle_sig_io",
        ]

    hooks_is_intrhandler += [sbc_is_intrhandler]

    sbc_gpio_handlers = [
        "w5500_intrhandler",
    ]

    # 1=just root directory
    # 2=just files in root directory
    # 3=just 1 level of subdirectories
    # 4=just 2 levels of subdirectories
    # ...
    #
    # TODO: Sniff this from config.h
    CONFIG_9P_SRV_MAX_DEPTH = 3

    def sbc_skip_call(chain: list[str], call: str) -> bool:
        if (
            len(chain) > 1
            and chain[-1] == "__assert_msg_fail"
            and call.endswith(":__lm_printf")
            and "__assert_msg_fail" in chain[:-1]
        ):
            return True
        if (
            len(chain) >= CONFIG_9P_SRV_MAX_DEPTH
            and "/srv.c:srv_util_pathfree" in call
            and all(
                ("/srv.c:srv_util_pathfree" in c)
                for c in chain[-CONFIG_9P_SRV_MAX_DEPTH:]
            )
        ):
            return True
        return False

    hooks_skip_call += [sbc_skip_call]

    # pico-sdk #######################################################

    if arg_pico_platform == "rp2040":

        def pico_is_intrhandler(name: str) -> bool:
            return name in [
                "gpio_default_irq_handler",
            ]

        hooks_is_intrhandler += [pico_is_intrhandler]

        def pico_indirect_callees(loc: str, line: str) -> list[str] | None:
            if "/3rd-party/pico-sdk/" not in loc or "/3rd-party/pico-sdk/lib/" in loc:
                return None
            m = re_call_other.fullmatch(line)
            call: str | None = m.group("func") if m else None

            match call:
                case "connect_internal_flash_func":
                    return ["rom_func_lookup(ROM_FUNC_CONNECT_INTERNAL_FLASH)"]
                case "flash_exit_xip_func":
                    return ["rom_func_lookup(ROM_FUNC_FLASH_EXIT_XIP)"]
                case "flash_range_erase_func":
                    return ["rom_func_lookup(ROM_FUNC_FLASH_RANGE_ERASE)"]
                case "flash_flush_cache_func":
                    return ["rom_func_lookup(ROM_FUNC_FLASH_FLUSH_CACHE)"]
                case "rom_table_lookup":
                    return ["rom_hword_as_ptr(BOOTROM_TABLE_LOOKUP_OFFSET)"]
            if "/flash.c:" in loc and "boot2_copyout" in line:
                return ["_stage2_boot"]
            if "/gpio.c:" in loc and call == "callback":
                return sbc_gpio_handlers
            if "/printf.c:" in loc:
                if call == "out":
                    return [
                        "_out_buffer",
                        "_out_null",
                        "_out_fct",
                    ]
                if "->fct(" in line:
                    return ["stdio_buffered_printer"]
            if "/stdio.c:" in loc:
                if call == "out_func":
                    return [
                        "stdio_out_chars_crlf",
                        "stdio_out_chars_no_crlf",
                    ]
                if call and (call.startswith("d->") or call.startswith("driver->")):
                    _, meth = call.split("->", 1)
                    match meth:
                        case "out_chars":
                            return ["stdio_uart_out_chars"]
                        case "out_flush":
                            return ["stdio_uart_out_flush"]
                        case "in_chars":
                            return ["stdio_uart_in_chars"]
            return None

        hooks_indirect_callees += [pico_indirect_callees]

        def pico_skip_call(chain: list[str], call: str) -> bool:
            if call == "_out_buffer" or call == "_out_fct":
                last = ""
                for pcall in chain:
                    if pcall in [
                        "__wrap_sprintf",
                        "__wrap_snprintf",
                        "__wrap_vsnprintf",
                        "vfctprintf",
                    ]:
                        last = pcall
                if last == "vfctprintf":
                    return call != "_out_fct"
                else:
                    return call == "_out_buffer"
            return False

        hooks_skip_call += [pico_skip_call]

        # src/rp2_common/hardware_divider/include/hardware/divider_helper.S
        save_div_state_and_lr = 5 * 4
        # src/rp2_common/pico_divider/divider_hardware.S
        save_div_state_and_lr_64 = 5 * 4
        all_nodes += [
            # src/rp2_common/pico_int64_ops/pico_int64_ops_aeabi.S
            synthetic_node("__aeabi_lmul", 4),
            # src/rp2_common/pico_divider/divider_hardware.S
            # s32 aliases
            synthetic_node("div_s32s32", 0, {"divmod_s32s32"}),
            synthetic_node("__aeabi_idiv", 0, {"divmod_s32s32"}),
            synthetic_node("__aeabi_idivmod", 0, {"divmod_s32s32"}),
            # s32 impl
            synthetic_node("divmod_s32s32", 0, {"divmod_s32s32_savestate"}),
            synthetic_node(
                "divmod_s32s32_savestate",
                save_div_state_and_lr,
                {"divmod_s32s32_unsafe"},
            ),
            synthetic_node("divmod_s32s32_unsafe", 2 * 4, {"__aeabi_idiv0"}),
            # u32 aliases
            synthetic_node("div_u32u32", 0, {"divmod_u32u32"}),
            synthetic_node("__aeabi_uidiv", 0, {"divmod_u32u32"}),
            synthetic_node("__aeabi_uidivmod", 0, {"divmod_u32u32"}),
            # u32 impl
            synthetic_node("divmod_u32u32", 0, {"divmod_u32u32_savestate"}),
            synthetic_node(
                "divmod_u32u32_savestate",
                save_div_state_and_lr,
                {"divmod_u32u32_unsafe"},
            ),
            synthetic_node("divmod_u32u32_unsafe", 2 * 4, {"__aeabi_idiv0"}),
            # s64 aliases
            synthetic_node("div_s64s64", 0, {"divmod_s64s64"}),
            synthetic_node("__aeabi_ldiv", 0, {"divmod_s64s64"}),
            synthetic_node("__aeabi_ldivmod", 0, {"divmod_s64s64"}),
            # s64 impl
            synthetic_node("divmod_s64s64", 0, {"divmod_s64s64_savestate"}),
            synthetic_node(
                "divmod_s64s64_savestate",
                save_div_state_and_lr_64 + (2 * 4),
                {"divmod_s64s64_unsafe"},
            ),
            synthetic_node(
                "divmod_s64s64_unsafe", 4, {"divmod_u64u64_unsafe", "__aeabi_ldiv0"}
            ),
            # u64 aliases
            synthetic_node("div_u64u64", 0, {"divmod_u64u64"}),
            synthetic_node("__aeabi_uldiv", 0, {"divmod_u64u64"}),
            synthetic_node("__aeabi_uldivmod", 0, {"divmod_u64u64"}),
            # u64 impl
            synthetic_node("divmod_u64u64", 0, {"divmod_u64u64_savestate"}),
            synthetic_node(
                "divmod_u64u64_savestate",
                save_div_state_and_lr_64 + (2 * 4),
                {"divmod_u64u64_unsafe"},
            ),
            synthetic_node(
                "divmod_u64u64_unsafe", (1 + 1 + 2 + 5 + 5 + 2) * 4, {"__aeabi_ldiv0"}
            ),
            # *_rem
            synthetic_node("divod_s64s64_rem", 2 * 4, {"divmod_s64s64"}),
            synthetic_node("divod_u64u64_rem", 2 * 4, {"divmod_u64u64"}),
            # src/rp2040/boot_stage2/boot2_${name,,}.S for name=W25Q080,
            # controlled by `#define PICO_BOOT_STAGE2_{name} 1` in
            # src/boards/include/boards/pico.h
            synthetic_node("_stage2_boot", 0),  # TODO
            # https://github.com/raspberrypi/pico-bootrom-rp2040
            synthetic_node(
                "rom_func_lookup(ROM_FUNC_CONNECT_INTERNAL_FLASH)", 0
            ),  # TODO
            synthetic_node("rom_func_lookup(ROM_FUNC_FLASH_EXIT_XIP)", 0),  # TODO
            synthetic_node("rom_func_lookup(ROM_FUNC_FLASH_FLUSH_CACHE)", 0),  # TODO
            synthetic_node("rom_hword_as_ptr(BOOTROM_TABLE_LOOKUP_OFFSET)", 0),  # TODO
        ]

    # TinyUSB device #################################################

    if any(fname.endswith("/tinyusb/src/device/usbd.c") for fname in c_fnames):
        tusb_config_fname = (
            arg_base_dir + "/cmd/sbc_harness/config/tusb_config.h"
        )  # TODO: FIXME
        re_tud_class = re.compile(
            r"^\s*#\s*define\s+(?P<k>CFG_TUD_(?:\S{3}|AUDIO|VIDEO|MIDI|VENDOR|USBTMC|DFU_RUNTIME|ECM_RNDIS))\s+(?P<v>\S+).*"
        )
        tusb_config: dict[str, bool] = {}
        with open(tusb_config_fname, "r") as fh:
            in_table = False
            for line in fh:
                line = line.rstrip()
                if m := re_tud_class.fullmatch(line):
                    k = m.group("k")
                    v = m.group("v")
                    tusb_config[k] = bool(int(v))

        usbd_fname = next(
            fname for fname in c_fnames if fname.endswith("/tinyusb/src/device/usbd.c")
        )
        tud_drivers: dict[str, set[str]] = {}
        re_tud_entry = re.compile(
            r"^\s+\.(?P<meth>\S+)\s*=\s*(?P<impl>[a-zA-Z0-9_]+)(?:,.*)?"
        )
        re_tud_if1 = re.compile(r"^\s*#\s*if (\S+)\s*")
        re_tud_if2 = re.compile(r"^\s*#\s*if (\S+)\s*\|\|\s*(\S+)\s*")
        re_tud_endif = re.compile(r"^\s*#\s*endif\s*")
        with open(usbd_fname, "r") as fh:
            in_table = False
            enabled = True
            for line in fh:
                line = line.rstrip()
                if in_table:
                    if m := re_tud_if1.fullmatch(line):
                        enabled = tusb_config[m.group(1)]
                    elif m := re_tud_if2.fullmatch(line):
                        enabled = tusb_config[m.group(1)] or tusb_config[m.group(2)]
                    elif re_tud_endif.fullmatch(line):
                        enabled = True
                    if m := re_tud_entry.fullmatch(line):
                        meth = m.group("meth")
                        impl = m.group("impl")
                        if meth == "name" or not enabled:
                            continue
                        if meth not in tud_drivers:
                            tud_drivers[meth] = set()
                        if impl != "NULL":
                            tud_drivers[meth].add(impl)
                    if line.startswith("}"):
                        in_table = False
                elif " _usbd_driver[] = {" in line:
                    in_table = True

        def tud_indirect_callees(loc: str, line: str) -> list[str] | None:
            if (
                "/tinyusb/" not in loc
                or "/tinyusb/src/host/" in loc
                or "_host.c:" in loc
            ):
                return None
            m = re_call_other.fullmatch(line)
            assert m
            call = m.group("func")
            if call == "_ctrl_xfer.complete_cb":
                return [
                    # "process_test_mode_cb",
                    "tud_vendor_control_xfer_cb",
                    *sorted(tud_drivers["control_xfer_cb"]),
                ]
            elif call.startswith("driver->"):
                return sorted(tud_drivers[call[len("driver->") :]])
            elif call == "event.func_call.func":
                # callback from usb_defer_func()
                return []

            return None

        hooks_indirect_callees += [tud_indirect_callees]

    # newlib #########################################################

    if arg_pico_platform == "rp2040":
        # This is accurate to
        # /usr/arm-none-eabi/lib/thumb/v6-m/nofp/libg.a as of
        # Parabola's arm-none-eabi-newlib 4.5.0.20241231-1.
        all_nodes += [
            # malloc
            synthetic_node("free", 8, {"_free_r"}),
            synthetic_node("malloc", 8, {"_malloc_r"}),
            synthetic_node("realloc", 8, {"_realloc_r"}),
            synthetic_node("aligned_alloc", 8, {"_memalign_r"}),
            synthetic_node("reallocarray", 24, {"realloc", "__errno"}),
            synthetic_node("_free_r", 0),  # TODO
            synthetic_node("_malloc_r", 0),  # TODO
            synthetic_node("_realloc_r", 0),  # TODO
            synthetic_node("_memalign_r", 0),  # TODO
            # execution
            synthetic_node("raise", 16, {"_getpid_r"}),
            synthetic_node("abort", 8, {"raise", "_exit"}),
            synthetic_node("longjmp", 0),
            synthetic_node("setjmp", 0),
            # <strings.h>
            synthetic_node("memcmp", 12),
            synthetic_node("memcpy", 28),
            synthetic_node("memset", 20),
            synthetic_node("strcmp", 16),
            synthetic_node("strlen", 8),
            synthetic_node("strncpy", 16),
            synthetic_node("strnlen", 8),
            # other
            synthetic_node("__errno", 0),
            synthetic_node("_getpid_r", 8, {"_getpid"}),
            synthetic_node("random", 8),
        ]

    # libgcc #########################################################

    if arg_pico_platform == "rp2040":
        # This is accurate to
        # /usr/lib/gcc/arm-none-eabi/14.2.0/thumb/v6-m/nofp/libgcc.a
        # as of Parabola's arm-none-eabi-gcc 14.2.0-1.
        all_nodes += [
            synthetic_node("__aeabi_idiv0", 0),
            synthetic_node("__aeabi_ldiv0", 0),
            synthetic_node("__aeabi_llsr", 0),
        ]

    # Tie it all together ############################################

    def thread_filter(name: str) -> int:
        return sbc_is_thread(name)

    def intrhandler_filter(name: str) -> int:
        name = name.rsplit(":", 1)[-1]
        for hook in hooks_is_intrhandler:
            if hook(name):
                return 1
        return 0

    def misc_filter(name: str) -> int:
        if name.endswith(":__lm_printf") or name == "__assert_msg_fail":
            return 1
        return 0

    def location_xform(loc: str) -> str:
        if not loc.startswith("/"):
            return loc
        parts = loc.split(":", 1)
        parts[0] = "./" + os.path.relpath(parts[0], arg_base_dir)
        return ":".join(parts)

    def indirect_callees(elem: VCGElem) -> list[str]:
        loc = elem.attrs.get("label", "")
        line = read_source(loc)

        for hook in hooks_indirect_callees:
            ret = hook(loc, line)
            if ret is not None:
                return ret

        placeholder = "__indirect_call"
        if m := re_call_other.fullmatch(line):
            placeholder += ":" + m.group("func")
        placeholder += " at " + location_xform(elem.attrs.get("label", ""))
        return [placeholder]

    def skip_call(chain: list[str], call: str) -> bool:
        for hook in hooks_skip_call:
            if hook(chain, call):
                return True
        return False

    analyze(
        ci_fnames=arg_ci_fnames,
        extra_nodes=all_nodes,
        app_func_filters={
            "Threads": thread_filter,
            "Interrupt handlers": intrhandler_filter,
            "Misc": misc_filter,
        },
        app_location_xform=location_xform,
        app_indirect_callees=indirect_callees,
        app_skip_call=skip_call,
        cfg_max_call_depth=100,
    )


if __name__ == "__main__":
    # Usage: stack.c.gen PICO_PLATFORM BASE_DIR [FILE...]
    # FILE may be compiled objects (*.c.o / *.c.obj) or C sources (*.c).
    if len(sys.argv) < 3:
        sys.exit(
            f"usage: {os.path.basename(sys.argv[0])} PICO_PLATFORM BASE_DIR [FILE...]"
        )
    pico_platform = sys.argv[1]
    base_dir = sys.argv[2]
    fnames = sys.argv[3:]

    # Each FOO.c.o / FOO.c.obj object maps to the FOO.c.ci call-graph file
    # at the same path.
    re_suffix = re.compile(r"\.c\.o(bj)?$")
    ci_fnames = [
        re_suffix.sub(".c.ci", fname) for fname in fnames if re_suffix.search(fname)
    ]

    c_fnames = [fname for fname in fnames if fname.endswith(".c")]

    main(
        arg_pico_platform=pico_platform,
        arg_base_dir=base_dir,
        arg_ci_fnames=ci_fnames,
        arg_c_fnames=c_fnames,
    )