# build-aux/measurestack/app_output.py - Generate `*_stack.c` files
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

from . import analyze
from .analyze import QName

# Only print_c is listed as public; print_group and next_power_of_2 are
# helpers that print_c calls.
# pylint: disable=unused-variable
__all__ = [
    "print_c",
]


def print_group(
    result: analyze.AnalyzeResult,
    location_xform: typing.Callable[[QName], str],
    grp_name: str,
) -> None:
    """Print the ``grp_name`` group of ``result`` as a fixed-width table.

    An empty group is printed as the single line ``= <grp_name> (empty) =``.
    Otherwise the table has these parts:

    - a ``=``-ruled title line containing the group name;
    - one line per entry with a nonzero ``nstatic``, sorted by QName, showing
      ``location_xform(qname)``, the entry's ``nstatic`` and, when
      ``cnt != 1``, a ``* cnt`` multiplier;
    - a ``-`` rule;
    - ``Total`` (the sum of ``nstatic * cnt`` over all entries) and
      ``Maximum`` (the largest single ``nstatic``);
    - a closing ``=`` rule.

    Entries whose ``nstatic`` is 0 are not listed, but their names still
    count toward the width of the name column.

    Args:
        result: Analysis result whose ``groups[grp_name].rows`` maps QName to
            an object with ``nstatic`` and ``cnt`` attributes.
        location_xform: Turns a QName into the label shown in the table.
        grp_name: Key into ``result.groups`` and the title of the table.
    """
    entries = result.groups[grp_name].rows
    if not entries:
        print(f"= {grp_name} (empty) =")
        return

    labels = {qname: location_xform(qname) for qname in entries}
    total = sum(entry.nstatic * entry.cnt for entry in entries.values())
    peak = max(entry.nstatic for entry in entries.values())

    # The name column must fit every label, and also the title prefix
    # "= <grp_name> " plus at least one rule character.
    name_w = max(len(grp_name) + 4, *(len(label) for label in labels.values()))
    # The total is never smaller than any printed number, so it sets the width.
    num_w = len(str(total))
    heavy_rule = "=" * name_w + " " + "=" * num_w
    light_rule = "-" * name_w + " " + "-" * num_w

    title = f"= {grp_name} "
    print(title + heavy_rule[len(title):])
    for qname in sorted(entries):
        entry = entries[qname]
        if entry.nstatic == 0:
            continue
        multiplier = "" if entry.cnt == 1 else f" * {entry.cnt}"
        print(f"{labels[qname]:<{name_w}} {entry.nstatic:>{num_w}}{multiplier}")
    print(light_rule)
    for caption, amount in (("Total", total), ("Maximum", peak)):
        print(f"{caption:<{name_w}} {amount:>{num_w}}")
    print(heavy_rule)


def next_power_of_2(x: int) -> int:
    """Return the smallest power of two strictly greater than ``x`` (x >= 0).

    An exact power of two is therefore doubled: 8 -> 16, 7 -> 8, 0 -> 1.
    The generated C comments describe this computation as
    ``LM_NEXT_POWER_OF_2``; presumably that macro has the same "strictly
    greater" semantics -- TODO confirm if either side changes.
    """
    return 2 ** x.bit_length()


def print_c(
    result: analyze.AnalyzeResult, location_xform: typing.Callable[[QName], str]
) -> None:
    """Print (to stdout) a C source file of coroutine stack-size constants.

    The output consists of:

    1. A C comment containing the "Threads" and "Interrupt handlers" tables
       (rendered by print_group).
    2. One ``const size_t CONFIG_COROUTINE_STACK_SIZE_<thread> = <n>;``
       definition per thread other than "main"/"_entry_point", each followed
       by a comment showing how the number was derived.  Then come
       commented-out "TOTAL (inc. stack guard)" and, if present,
       "MAIN/KERNEL" rows.
    3. A C comment containing the "Misc" table plus a warning line for every
       function in ``result.missing`` and ``result.dynamic``.
    4. A C comment containing the "Extra" table plus a line for every
       function in ``result.included_funcs``.

    Each thread's stack size is its own static usage plus the largest
    interrupt handler's static usage.  Non-main threads are then rounded
    down-from-a-power-of-two so that ``size + stack_guard_size`` is a power
    of two (see next_power_of_2).  The main/kernel thread is reported but
    not rounded.

    Empty "Interrupt handlers" or (non-main) "Threads" groups are tolerated,
    matching print_group's handling of empty groups.

    Args:
        result: Output of the stack analysis; must contain the "Threads",
            "Interrupt handlers", "Misc" and "Extra" groups.
        location_xform: Turns a QName into the display string used in the
            tables and warnings.
    """
    print("#include <stddef.h> /* for size_t */")
    print()
    print("/*")
    print_group(result, location_xform, "Threads")
    print_group(result, location_xform, "Interrupt handlers")
    print("*/")
    # Room for the largest interrupt handler is added to every thread's
    # stack below.  default=0 keeps a build with no interrupt handlers from
    # crashing with "max() arg is an empty sequence".
    intrstack = max(
        (v.nstatic for v in result.groups["Interrupt handlers"].rows.values()),
        default=0,
    )
    # Per-stack guard region; it is included in the "TOTAL" row and in the
    # power-of-two rounding.
    stack_guard_size = 16 * 2

    class CrRow(typing.NamedTuple):
        name: str  # thread function's base name
        cnt: int  # number of instances of this thread
        base: int  # the thread's own static stack usage
        size: int  # final stack size to emit (excludes the guard)

    rows: list[CrRow] = []
    mainrow: CrRow | None = None
    for funcname, val in result.groups["Threads"].rows.items():
        name = str(funcname.base())
        base = val.nstatic
        size = base + intrstack
        if name in ("main", "_entry_point"):
            mainrow = CrRow(name=name, cnt=1, base=base, size=size)
        else:
            # Round so that stack + guard is a power of two.
            size = next_power_of_2(size + stack_guard_size) - stack_guard_size
            rows.append(CrRow(name=name, cnt=val.cnt, base=base, size=size))

    # Column widths.  default=0 handles the case where the only thread is
    # main/_entry_point, so `rows` is empty.
    namelen = max((len(r.name) for r in rows), default=0)
    baselen = max((len(str(r.base)) for r in rows), default=0)
    sizesum = sum(r.cnt * (r.size + stack_guard_size) for r in rows)
    # Size the number column from every value that is printed in it, not just
    # the total; a row with cnt == 0 is excluded from the total.
    printed_sizes = [sizesum] + [r.size for r in rows]
    if mainrow:
        printed_sizes.append(mainrow.size)
    sizelen = len(str(max(printed_sizes)))

    def print_row(comment: bool, name: str, size: int, eqn: str | None = None) -> None:
        """Print one aligned ``CONFIG_COROUTINE_STACK_SIZE_*`` line.

        With ``comment=True`` the whole line is wrapped in ``/* ... */``
        instead of being a definition.  ``eqn``, if given, is appended as a
        trailing C comment.
        """
        prefix = "const size_t CONFIG_COROUTINE_STACK_SIZE_"
        if comment:
            # Pad so that " = " lines up with the non-comment rows.
            print(f"/* {name}".ljust(len(prefix) + namelen), end="")
        else:
            print(f"{prefix}{name:<{namelen}}", end="")
        print(f" = {size:>{sizelen}};", end="")
        if comment:
            print(" */", end="")
        elif eqn:
            # Pad the width of " */" so eqn comments line up.
            print("   ", end="")
        if eqn:
            print(f" /* {eqn} */", end="")
        print()

    for row in sorted(rows):
        print_row(
            False,
            row.name,
            row.size,
            f"LM_NEXT_POWER_OF_2({row.base:>{baselen}}+{intrstack}+{stack_guard_size})-{stack_guard_size}",
        )
    print_row(True, "TOTAL (inc. stack guard)", sizesum)
    if mainrow:
        print_row(
            True,
            "MAIN/KERNEL",
            mainrow.size,
            f"                   {mainrow.base:>{baselen}}+{intrstack}",
        )
    print()
    print("/*")
    print_group(result, location_xform, "Misc")

    for funcname in sorted(result.missing):
        print(f"warning: missing: {location_xform(funcname)}")
    for funcname in sorted(result.dynamic):
        print(f"warning: dynamic-stack-usage: {location_xform(funcname)}")

    print("*/")
    print()
    print("/*")
    print_group(result, location_xform, "Extra")
    for funcname in sorted(result.included_funcs):
        print(f"included: {location_xform(funcname)}")
    print("*/")