#!/usr/bin/env python3
# build-aux/stack.c.gen - Analyze stack sizes for compiled objects
#
# Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import re
import sys
import typing

################################################################################
#
# Parse the "VCG" language
#
# https://www.rw.cdl.uni-saarland.de/people/sander/private/html/gsvcg1.html
#
# The formal syntax is found at
# ftp://ftp.cs.uni-sb.de/pub/graphics/vcg/vcg.tgz `doc/grammar.txt`.


class VCGElem:
    """One `node: { ... }` or `edge: { ... }` entry read from a VCG file."""

    # Entry kind: "node" or "edge".
    typ: str
    # 1-based number of the input line the entry was read from.
    lineno: int
    # Attribute name -> value; quoted values are stored without the
    # surrounding quotes and with escape sequences decoded.
    attrs: dict[str, str]


def parse_vcg(reader: typing.TextIO) -> typing.Iterator[VCGElem]:
    """Parse the line-oriented subset of VCG read from `reader`.

    Every input line must be one of:

    - the graph header (`graph: { title: "..."`) or a lone `}`; both
      are skipped;
    - a complete `node: { key: value ... }` or `edge: { key: value ... }`
      entry, which is yielded as a VCGElem.

    Values are either bare tokens or double-quoted strings.  Inside a
    quoted string only the escapes `\\n`, `\\"` and `\\\\` are accepted.

    Raises:
        SyntaxError: if a line does not fit the shapes above, repeats
            a key, or uses an unsupported escape.  `lineno` and
            `offset` are 1-based; `text` is the offending line.
    """
    re_beg = re.compile(r"(edge|node):\s*\{\s*")
    _re_tok = r"[a-zA-Z_][a-zA-Z0-9_]*"
    # A backslash must only be consumed as part of an escape pair;
    # otherwise `\"` would be read as "backslash, end of string".
    _re_str = r'"(?:[^"\\]|\\.)*"'
    re_attr = re.compile(
        "(" + _re_tok + r")\s*:\s*(" + _re_tok + "|" + _re_str + r")\s*"
    )
    re_end = re.compile(r"\}\s*$")
    re_skip = re.compile(r"(graph:\s*\{\s*title\s*:\s*" + _re_str + r"\s*|\})\s*")
    re_esc = re.compile(r"\\.")
    escapes = {"n": "\n", '"': '"', "\\": "\\"}

    def syntax_error(msg: str, lineno: int, line: str, pos: int) -> SyntaxError:
        # `pos` is a 0-based index into `line`; SyntaxError.offset is 1-based.
        err = SyntaxError(msg)
        err.lineno = lineno
        err.offset = pos + 1
        err.text = line
        return err

    for lineno, line in enumerate(reader, start=1):
        if re_skip.fullmatch(line):
            continue

        m = re_beg.match(line)
        if not m:
            raise syntax_error("does not look like a VCG line", lineno, line, 0)

        elem = VCGElem()
        elem.lineno = lineno
        elem.typ = m.group(1)
        elem.attrs = {}

        pos = m.end()
        while not re_end.match(line, pos=pos):
            m = re_attr.match(line, pos=pos)
            if not m:
                raise syntax_error("unexpected character", lineno, line, pos)
            k = m.group(1)
            v = m.group(2)
            if k in elem.attrs:
                raise syntax_error(f"duplicate key: {repr(k)}", lineno, line, pos)
            if v.startswith('"'):
                body = v[1:-1]
                # Reject unsupported escapes before decoding, so the error
                # is reported for the first offender in reading order.
                for esc in re_esc.finditer(body):
                    if esc.group(0)[1] not in escapes:
                        raise syntax_error(
                            f"invalid escape code {repr(esc.group(0))}",
                            lineno,
                            line,
                            pos,
                        )
                v = re_esc.sub(lambda esc: escapes[esc.group(0)[1]], body)
            elem.attrs[k] = v
            pos = m.end()

        yield elem


################################################################################
# Main application


class Node:
    """A function in the call graph, built from one VCG `node` entry.

    The entry's `.label` has the form
    "{funcname}\\n{location}\\n{nstatic} bytes (static)\\n{ndynamic} dynamic objects".
    """

    # Taken from `.title`.  `static` and `__weak` functions are prefixed
    # with the compilation unit's .c file; that is fine for static
    # functions, but __weak ones have to be handled specially.
    funcname: str

    # The remaining fields of `.label`.
    location: str
    nstatic: int
    ndynamic: int

    # Callees: targets of the edges whose `.sourcename` is this node.
    calls: set[str]


re_location = re.compile(r"(?P<filename>.+):(?P<row>[0-9]+):(?P<col>[0-9]+)")


def read_source(location: str) -> str:
    """Return the source text starting at `location`, up to the end of its line.

    `location` has the form "FILENAME:ROW:COL", with ROW and COL counted
    from 1.  Trailing whitespace (including the newline) is stripped
    from the result.

    Raises:
        ValueError: if `location` is not of that form, or ROW or COL is 0.
        IndexError: if the file has fewer than ROW lines.
        OSError: if the file cannot be opened.
    """
    m = re_location.fullmatch(location)
    if not m:
        raise ValueError(f"unexpected label value {repr(location)}")
    filename = m.group("filename")
    row = int(m.group("row")) - 1
    col = int(m.group("col")) - 1
    # A 0 would turn into index -1 and silently select the last
    # line/character instead of the requested one.
    if row < 0 or col < 0:
        raise ValueError(f"unexpected label value {repr(location)}")
    with open(filename, "r") as fh:
        # Stream the file and stop at the wanted line rather than loading
        # every line into memory.
        for index, text in enumerate(fh):
            if index == row:
                return text[col:].rstrip()
    raise IndexError(f"{filename} has no line {row + 1}")


def main(ci_fnames: list[str]) -> None:
    re_node_label = re.compile(
        r"(?P<funcname>[^\n]+)\n"
        + r"(?P<location>[^\n]+:[0-9]+:[0-9]+)\n"
        + r"(?P<nstatic>[0-9]+) bytes \(static\)\n"
        + r"(?P<ndynamic>[0-9]+) dynamic objects",
        flags=re.MULTILINE,
    )
    re_call_vcall = re.compile(r"VCALL\((?P<obj>[^,]+), (?P<meth>[^,)]+)[,)].*")
    re_call_other = re.compile(r"(?P<func>[^(]+)\(.*")

    graph: dict[str, Node] = dict()
    qualified: dict[str, set[str]] = dict()

    def handle_elem(elem: VCGElem) -> None:
        match elem.typ:
            case "node":
                node = Node()
                node.calls = set()
                skip = False
                for k, v in elem.attrs.items():
                    match k:
                        case "title":
                            node.funcname = v
                        case "label":
                            if elem.attrs.get("shape", "") != "ellipse":
                                m = re_node_label.fullmatch(v)
                                if not m:
                                    raise ValueError(
                                        f"unexpected label value {repr(v)}"
                                    )
                                node.location = m.group("location")
                                node.nstatic = int(m.group("nstatic"))
                                node.ndynamic = int(m.group("ndynamic"))
                        case "shape":
                            if v != "ellipse":
                                raise ValueError(f"unexpected shape value {repr(v)}")
                            skip = True
                        case _:
                            raise ValueError(f"unknown edge key {repr(k)}")
                if not skip:
                    if node.funcname in graph:
                        raise ValueError(f"duplicate node {repr(node.funcname)}")
                    graph[node.funcname] = node
                    if ":" in node.funcname:
                        _, shortname = node.funcname.rsplit(":", 1)
                        if shortname not in qualified:
                            qualified[shortname] = set()
                        qualified[shortname].add(node.funcname)
            case "edge":
                caller: str | None = None
                callee: str | None = None
                for k, v in elem.attrs.items():
                    match k:
                        case "sourcename":
                            caller = v
                        case "targetname":
                            callee = v
                        case "label":
                            pass
                        case _:
                            raise ValueError(f"unknown edge key {repr(k)}")
                if caller is None or callee is None:
                    raise ValueError(f"incomplete edge: {repr(elem.attrs)}")
                if caller not in graph:
                    raise ValueError(f"unknown caller: {caller}")
                if callee == "__indirect_call":
                    callstr = read_source(elem.attrs.get("label", ""))
                    if m := re_call_vcall.fullmatch(callstr):
                        callee += f":{m.group('obj')}->vtable->{m.group('meth')}"
                    elif m := re_call_other.fullmatch(callstr):
                        callee += f":{m.group('func')}"
                    else:
                        callee += f':{elem.attrs.get("label", "")}'
                graph[caller].calls.add(callee)
            case _:
                raise ValueError(f"unknown elem type {repr(elem.typ)}")

    for ci_fname in ci_fnames:
        with open(ci_fname, "r") as fh:
            for elem in parse_vcg(fh):
                handle_elem(elem)

    missing: set[str] = set()
    cycles: set[str] = set()

    print("/*")

    dbg = False

    def nstatic(funcname: str, chain: list[str] = []) -> int:
        nonlocal dbg
        if funcname not in graph:
            if f"__wrap_{funcname}" in graph:
                # Handle `ld --wrap` functions
                funcname = f"__wrap_{funcname}"
            elif funcname in qualified and len(qualified[funcname]) == 1:
                # Handle `__weak` functions
                funcname = sorted(qualified[funcname])[0]
            else:
                missing.add(funcname)
                return 0
        if funcname in chain:
            if "__assert_msg_fail" in chain:
                if funcname == "__wrap_printf":
                    return 0
                pass
            else:
                cycles.add(f"{chain[chain.index(funcname):] + [funcname]}")
                return 9999999
        node = graph[funcname]
        if dbg:
            print(f"//dbg: {funcname}\t{node.nstatic}")
        return node.nstatic + max(
            [0, *[nstatic(call, chain + [funcname]) for call in node.calls]]
        )

    def thread_filter(name: str) -> bool:
        return name.endswith("_cr") or name == "main"

    namelen = max(len(name) for name in graph if thread_filter(name))
    numlen = max(len(str(nstatic(name))) for name in graph if name.endswith("_cr"))
    print(("=" * namelen) + " " + "=" * numlen)

    for funcname in graph:
        if thread_filter(funcname):
            # dbg = "dhcp" in funcname
            print(f"{funcname.ljust(namelen)} {str(nstatic(funcname)).rjust(numlen)}")

    print(("=" * namelen) + " " + "=" * numlen)

    for funcname in sorted(missing):
        print(f"warning: missing: {funcname}")
    for cycle in sorted(cycles):
        print(f"warning: cycle: {cycle}")

    print("*/")


if __name__ == "__main__":
    # Only arguments naming compiled C objects (`*.c.o` / `*.c.obj`) are
    # considered; each one is mapped to the `.c.ci` file next to it.
    re_suffix = re.compile(r"\.c\.o(bj)?$")
    object_fnames = filter(re_suffix.search, sys.argv[1:])
    main([re_suffix.sub(".c.ci", fname) for fname in object_fnames])