# gdb-helpers/libcr.py - GDB helpers for libcr.
#
# Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import contextlib
import time
import typing

import gdb
import gdb.unwinder

# GDB helpers ##################################################################


class _gdb_Locus(typing.Protocol):
    """Typing-only shape of anything that owns a list of frame unwinders.

    `gdb.Objfile`, `gdb.Progspace` and the `gdb` module itself all expose a
    `frame_unwinders` list; this protocol lets them be handled uniformly.
    """

    @property
    def frame_unwinders(self) -> list["gdb._Unwinder"]:
        """The mutable list of unwinders registered on this locus."""


def gdb_unregister_unwinder(
    locus: gdb.Objfile | gdb.Progspace | None, unwinder: "gdb._Unwinder"
) -> None:
    """Remove `unwinder` from the unwinder list of `locus`.

    This is the inverse of `gdb.unwinder.register_unwinder()`.  A `locus` of
    None means the global list on the `gdb` module itself.

    Raises ValueError (from `list.remove`) if `unwinder` is not registered there.
    """
    if locus is None:
        owner: _gdb_Locus = typing.cast(_gdb_Locus, gdb)
    else:
        owner = locus
    owner.frame_unwinders.remove(unwinder)
    # Any frames GDB already built used the old unwinder set; rebuild them.
    gdb.invalidate_cached_frames()


def gdb_is_on_os() -> bool:
    """Return True if GDB's `info proc` command succeeds.

    NOTE(review): presumably this distinguishes a hosted (OS) inferior from a
    bare-metal target where `info proc` is unsupported; confirm.
    """
    try:
        gdb.execute("info proc", to_string=True)
    except gdb.error:
        return False
    return True


class gdb_JmpBuf:
    """Our own in-Python GDB-specific implementation of `jmp_buf`

    A snapshot taken by `gdb_setjmp()` and consumed by `gdb_longjmp()`.
    Attributes are assigned after construction; a fresh instance has none.
    """

    # Frame level that was selected when the snapshot was taken.
    level: int
    # Register name -> value text, as printed by `info registers`.
    registers: dict[str, str]


def gdb_setjmp() -> gdb_JmpBuf:
    """Our own in-Python GDB-specific implementation of `setjmp()`

    Record the currently selected frame level plus the register values of
    frame 0 (the name and first value column of each `info registers` line),
    so that `gdb_longjmp()` can restore both later.  The user's frame
    selection is restored before returning, even if reading registers fails.
    """
    buf = gdb_JmpBuf()
    buf.level = gdb.selected_frame().level()
    gdb.execute("select-frame level 0")
    try:
        buf.registers = {}
        for line in gdb.execute("info registers", to_string=True).split("\n"):
            # "<name> <value> <natural-format...>"; keep only name and value.
            words = line.split(maxsplit=2)
            if len(words) < 2:
                continue
            buf.registers[words[0]] = words[1]
    finally:
        # Don't leave the user stranded on frame 0 if `info registers` raised.
        gdb.execute(f"select-frame level {buf.level}")
    return buf


def gdb_longjmp(buf: gdb_JmpBuf) -> None:
    """Our own in-Python GDB-specific implementation of `longjmp()`

    Write every register recorded in `buf` (by `gdb_setjmp()`) back into
    frame 0, invalidate GDB's cached frames, then re-select the frame level
    recorded in `buf`.
    """

    # Register assignments below are made with frame 0 selected.
    gdb.execute("select-frame level 0")

    if (
        ("sp" in buf.registers)
        and ("msp" in buf.registers)
        and ("psp" in buf.registers)
        and ("control" in buf.registers)
    ):
        # On ARM, 'sp' is an alias for either 'msp' or 'psp'
        # (depending on 'control'&(1<<1)).  We must set all 3 before
        # fussing with 'xPSR' or frames, or GDB will get upset at us
        # about "Invalid state".
        # NOTE(review): only this first assignment captures its output
        # (to_string=True) -- presumably to swallow a message GDB prints
        # here; confirm.
        gdb.execute(f"set $sp = {buf.registers['sp']}", to_string=True)
        gdb.execute(f"set $msp = {buf.registers['msp']}")
        gdb.execute(f"set $psp = {buf.registers['psp']}")

    # Restore every captured register (including the ones set above), in the
    # order `info registers` listed them.
    for reg, val in buf.registers.items():
        gdb.execute(f"set ${reg} = {val}")
    # Register values changed underneath GDB; drop its cached frame chain.
    gdb.invalidate_cached_frames()

    gdb.execute(f"select-frame level {buf.level}")


# Core libcr functionality #####################################################


class CrGlobals:
    """Python-side mirror of libcr's global state in the inferior.

    Holds one `CrCoroutine` per slot of the inferior's `coroutine_table`,
    owns the internal `CrBreakpoint` used by `readjmp()`, and listens for GDB
    `cont` events so that per-coroutine register snapshots are discarded
    whenever the inferior resumes.
    """

    coroutines: list["CrCoroutine"]  # coroutines[i] has id i+1
    _breakpoint: "CrBreakpoint"
    # Threads seen so far; a `cont` that reveals new threads is ignored.
    _known_threads: set[gdb.InferiorThread]

    def __init__(self) -> None:
        num = int(
            gdb.parse_and_eval("sizeof(coroutine_table)/sizeof(coroutine_table[0])")
        )

        self.coroutines = [CrCoroutine(self, i + 1) for i in range(num)]

        # Only armed for the duration of readjmp().
        self._breakpoint = CrBreakpoint()
        self._breakpoint.enabled = False

        self._known_threads = set()

        gdb.events.cont.connect(self._on_cont)

    def delete(self) -> None:
        """Tear down: drop the coroutines, delete the breakpoint, unhook `cont`."""
        self.coroutines = []
        self._breakpoint.delete()
        gdb.events.cont.disconnect(self._on_cont)

    def readjmp(self, env_ptr_expr: str) -> gdb_JmpBuf:
        """Return the register context stored in an inferior `jmp_buf`.

        `env_ptr_expr` is a C expression for a pointer to the `jmp_buf`.  It is
        passed to the inferior function `cr_gdb_readjmp()` while the internal
        breakpoint on `cr_gdb_breakpoint` is armed; the registers recorded by
        that breakpoint (presumably hit during the call) are returned.
        """
        self._breakpoint.enabled = True
        try:
            gdb.execute(f"call (void)cr_gdb_readjmp({env_ptr_expr})")
        finally:
            # Always disarm: the breakpoint's unwinder truncates every
            # backtrace while enabled, so leaving it on after a failed call
            # would break all later unwinding.
            self._breakpoint.enabled = False
        if gdb_is_on_os():
            # NOTE(review): purpose of queuing SIGWINCH is not evident from
            # here; presumably it compensates for a side effect of the
            # inferior call on hosted targets.  Confirm.
            gdb.execute("queue-signal SIGWINCH")
        return self._breakpoint.env

    def _on_cont(self, event: gdb.Event) -> None:
        """`cont` handler: insist on resuming from the running coroutine.

        Continues that coincide with new threads appearing are ignored.
        Otherwise, if GDB is viewing a coroutine other than the running one,
        the user is told how to recover; then every coroutine's saved
        `_cont_env` is discarded, since it is stale once the inferior runs.
        """
        cur_threads = set(gdb.selected_inferior().threads())
        if cur_threads - self._known_threads:
            # Ignore thread creation events.
            self._known_threads = cur_threads
            return
        # Read the inferior's `coroutine_running` once; nothing below changes it.
        running = self.coroutine_running
        if running and not running.is_selected():
            if True:  # https://sourceware.org/bugzilla/show_bug.cgi?id=32428
                # Workaround for the bug above: the automatic longjmp back to
                # the running coroutine (below) is disabled; block instead.
                print("Must return to running coroutine before continuing.")
                print("Hit ^C twice then run:")
                print(f"  cr select {running.id}")
                while True:
                    time.sleep(1)
            assert running._cont_env
            gdb_longjmp(running._cont_env)
        for cr in self.coroutines:
            cr._cont_env = None

    def is_valid_cid(self, cid: int) -> bool:
        """Return True if `cid` is a 1-based index into `coroutines`."""
        return 0 < cid <= len(self.coroutines)

    @property
    def coroutine_running(self) -> "CrCoroutine | None":
        """The coroutine named by the inferior's `coroutine_running`, if valid."""
        cid = int(gdb.parse_and_eval("coroutine_running"))
        if not self.is_valid_cid(cid):
            return None
        return self.coroutines[cid - 1]

    @property
    def coroutine_selected(self) -> "CrCoroutine | None":
        """The first coroutine whose stack contains the current `$sp`, if any."""
        for cr in self.coroutines:
            if cr.is_selected():
                return cr
        return None

    @property
    def CR_NONE(self) -> gdb.Value:
        return gdb.parse_and_eval("CR_NONE")

    @property
    def CR_RUNNING(self) -> gdb.Value:
        return gdb.parse_and_eval("CR_RUNNING")


class CrBreakpointUnwinder(gdb.unwinder.Unwinder):
    """Unwinder that deliberately cuts the stack off at the first frame.

    Enabled only temporarily, so that
    gdb/breakpoint.c:check_longjmp_breakpoint_for_call_dummy() does not
    garbage-collect the `call`-dummy-frame too early.
    """

    def __init__(self) -> None:
        super().__init__("cr_breakpoint_unwinder")

    # The .pyi is wrong; it says `Frame` instead of `PendingFrame`.
    def __call__(self, pending_frame: gdb.PendingFrame) -> gdb.UnwindInfo | None:
        # The UnwindInfo built here never gets `.add_saved_register("pc", ...)`,
        # which makes GDB stop unwinding with stop_reason=UNWIND_NO_SAVED_PC.
        stack_ptr = pending_frame.read_register("sp")
        prog_ctr = pending_frame.read_register("pc")
        frame_id = gdb.unwinder.FrameId(sp=stack_ptr, pc=prog_ctr)
        return pending_frame.create_unwind_info(frame_id)


class CrBreakpoint(gdb.Breakpoint):
    """Internal breakpoint on `cr_gdb_breakpoint` that snapshots registers.

    When hit, `stop()` stores a `gdb_setjmp()` snapshot in `self.env` and
    lets execution continue.  It is paired with a `CrBreakpointUnwinder`;
    setting `enabled` toggles both together.
    """

    env: gdb_JmpBuf  # registers captured by the most recent hit
    _unwinder_locus: gdb.Objfile  # objfile defining `cr_gdb_readjmp`
    _unwinder: CrBreakpointUnwinder

    def __init__(self) -> None:
        self.env = gdb_JmpBuf()

        # Register the unwinder on the objfile that defines cr_gdb_readjmp,
        # initially disabled.  The third argument is `replace=True`: an
        # existing unwinder with the same name is replaced rather than
        # raising (presumably to tolerate re-initialization).
        self._unwinder = CrBreakpointUnwinder()
        readjmp_sym = gdb.lookup_global_symbol("cr_gdb_readjmp")
        assert readjmp_sym
        self._unwinder_locus = readjmp_sym.symtab.objfile
        gdb.unwinder.register_unwinder(self._unwinder_locus, self._unwinder, True)
        self._unwinder.enabled = False

        super().__init__(
            function="cr_gdb_breakpoint", type=gdb.BP_BREAKPOINT, internal=True
        )

    @property
    def enabled(self) -> bool:
        return super().enabled

    @enabled.setter
    def enabled(self, value: bool) -> None:
        # Keep the unwinder in lock-step with the breakpoint.
        self._unwinder.enabled = value
        # `super().enabled = value` is not supported by `super()` proxies, so
        # call the base class's data-descriptor setter directly.
        gdb.Breakpoint.enabled.__set__(self, value)  # type: ignore

    def stop(self) -> bool:
        # Suspend the stack-truncating unwinder while taking the snapshot
        # (gdb_setjmp() queries and re-selects frames), then re-arm it.
        assert self._unwinder.enabled
        self._unwinder.enabled = False
        self.env = gdb_setjmp()
        self._unwinder.enabled = True
        return False  # don't stop

    def delete(self) -> None:
        # Unregister our unwinder before deleting the breakpoint itself.
        gdb_unregister_unwinder(self._unwinder_locus, self._unwinder)
        super().delete()


def cr_select_top_frame() -> None:
    """Select the innermost frame that is not inside libcr itself.

    Starting from frame 0, walk outward past every frame whose function name
    begins with `cr_` or `_cr_`.  If the walk reaches an outermost frame that
    is still a libcr frame, fall back to selecting frame 0.
    """
    gdb.execute("select-frame level 0")
    innermost = gdb.selected_frame()
    frame = innermost
    while True:
        fn_name = frame.name()
        if not (fn_name and fn_name.startswith(("cr_", "_cr_"))):
            return
        caller = frame.older()
        if not caller:
            innermost.select()
            return
        caller.select()
        frame = caller


class CrCoroutine:
    """One slot of the inferior's `coroutine_table`, identified by a 1-based id."""

    cr_globals: CrGlobals
    cid: int
    # Register snapshot to restore when GDB switches back to viewing this
    # coroutine; CrGlobals clears it on every continue.
    _cont_env: gdb_JmpBuf | None

    def __init__(self, cr_globals: CrGlobals, cid: int) -> None:
        self.cr_globals = cr_globals
        self.cid = cid
        self._cont_env = None

    @property
    def id(self) -> int:
        """The coroutine id (1-based index into `coroutine_table`)."""
        return self.cid

    @property
    def state(self) -> gdb.Value:
        """The slot's `state` field, as a GDB value."""
        return gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].state")

    @property
    def name(self) -> str:
        """The slot's NUL-terminated `name` field, decoded as UTF-8.

        Only the bytes before the first NUL are decoded, so leftover bytes
        after the terminator cannot cause a decode error.  Invalid UTF-8 in
        the name itself (e.g. a multi-byte character cut off by truncation)
        is replaced with U+FFFD instead of raising.
        """
        size = int(gdb.parse_and_eval("sizeof(coroutine_table[0].name)"))
        field = gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].name")
        # `char` may be signed (e.g. on x86), in which case bytes >= 0x80
        # read back as negative ints; mask to recover the raw byte value.
        raw = bytes(int(field[i]) & 0xFF for i in range(size))
        return raw.split(b"\x00", maxsplit=1)[0].decode("UTF-8", errors="replace")

    def is_selected(self) -> bool:
        """Return True if GDB's current `$sp` lies within this coroutine's stack."""
        sp = int(gdb.parse_and_eval("$sp"))
        lo = int(gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].stack"))
        hi = lo + int(gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].stack_size"))
        return lo <= sp < hi

    def select(self, level: int = -1) -> None:
        """Load this coroutine's register context into GDB.

        The registers of whichever coroutine GDB is currently viewing are
        first saved into that coroutine's `_cont_env`.  If this coroutine has
        a saved `_cont_env`, it is restored as-is.  Otherwise its context is
        read from libcr: `coroutine_add_env` if it is in state CR_RUNNING (but
        is not `coroutine_running`), else its own `env`; then the innermost
        non-libcr frame is selected.

        `level` is currently unused.
        """
        selected = self.cr_globals.coroutine_selected
        if selected:
            selected._cont_env = gdb_setjmp()

        if self._cont_env:
            gdb_longjmp(self._cont_env)
            return

        env: gdb_JmpBuf
        if self == self.cr_globals.coroutine_running:
            # _on_cont()/select() should have saved the running coroutine's
            # registers before GDB could view anything else.
            raise AssertionError(
                f"running coroutine {self.cid} has no saved context"
            )
        elif self.state == self.cr_globals.CR_RUNNING:
            env = self.cr_globals.readjmp("&coroutine_add_env")
        else:
            env = self.cr_globals.readjmp(f"&coroutine_table[{self.cid-1}].env")
        gdb_longjmp(env)
        cr_select_top_frame()

    @contextlib.contextmanager
    def with_selected(self) -> typing.Iterator[None]:
        """Temporarily view this coroutine, restoring GDB's registers afterwards.

        `select()` runs inside the `try` so that the original registers are
        restored even if selecting fails part-way through.
        """
        saved_env = gdb_setjmp()
        try:
            self.select()
            yield
        finally:
            gdb_longjmp(saved_env)


# User-facing commands #########################################################


class CrCommand(gdb.Command):
    """Use this command for libcr coroutines."""

    # NOTE: the class docstring above is the help text GDB shows for `cr`.
    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        # Final positional argument is `prefix=True`: `cr` is a prefix command
        # under which `cr list` and `cr select` are registered.
        super().__init__("cr", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE, True)

    def invoke(self, arg: str, from_tty: bool) -> None:
        # A bare `cr` just prints the prefix command's help.
        gdb.execute("help cr")


class CrListCommand(gdb.Command):
    """List libcr coroutines.
    Usage: cr list

    In the output:
    - the 'R' marker indicates the currently-running coroutine
    - the 'G' marker indicates the coroutine that GDB is viewing; this may be changed with `cr select`
    """

    # NOTE: the class docstring above is the help text GDB shows for `cr list`.
    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        gdb.Command.__init__(self, "cr list", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE)

    def invoke(self, arg: str, from_tty: bool) -> None:
        argv = gdb.string_to_argv(arg)
        if len(argv) != 0:
            raise gdb.GdbError("Usage: cr list")

        # CR_NONE is treated as invariant for the duration of the listing.
        cr_none = self.cr_globals.CR_NONE

        rows: list[tuple[str, str, str, str, str]] = [
            ("", "Id", "Name", "State", "Frame")
        ]
        for cr in self.cr_globals.coroutines:
            state = cr.state
            if state == cr_none:
                continue
            markers = "".join(
                [
                    "R" if cr == self.cr_globals.coroutine_running else " ",
                    "G" if cr.is_selected() else " ",
                ]
            )
            rows.append(
                (
                    markers,
                    str(cr.id),
                    repr(cr.name),
                    str(state),
                    self._pretty_frame(cr, from_tty),
                )
            )

        widths: list[int] = [
            max(len(row[col]) for row in rows) for col in range(len(rows[0]))
        ]

        def line(row: tuple[str, str, str, str, str]) -> str:
            """Format one table row; the last (Frame) column is not padded."""

            def cell(col: int) -> str:
                return row[col].ljust(widths[col])

            return f"{cell(0)} {cell(1)}  {cell(2)}  {cell(3)}  {row[4]}"

        # Truncate rows to the terminal width, but never below the header width.
        maxline = 0
        if screenwidth := gdb.parameter("width"):
            assert isinstance(screenwidth, int)
            maxline = max(screenwidth, len(line(rows[0])))

        for row in rows:
            l = line(row)
            if maxline and len(l) > maxline:
                l = l[:maxline]
            print(l)

    def _pretty_frame(self, cr: CrCoroutine, from_tty: bool) -> str:
        """Describe `cr`'s innermost non-libcr frame for the "Frame" column.

        Returns the first line of GDB's `frame` output with its leading
        `#N` token removed, an "err: ..." description if viewing the
        coroutine failed, or "" if there is nothing after the `#N` token.
        """
        try:
            with cr.with_selected():
                saved_level = gdb.selected_frame().level()
                cr_select_top_frame()
                full = gdb.execute("frame", from_tty=from_tty, to_string=True)
                gdb.execute(f"select-frame level {saved_level}")
        except Exception as e:
            # Best effort: one unreadable coroutine must not break the listing.
            full = "#0 err: " + str(e)
        first_line = full.split("\n", maxsplit=1)[0]
        # Drop the leading "#N" token; tolerate empty or one-token output.
        parts = first_line.split(maxsplit=1)
        return parts[1] if len(parts) > 1 else ""


class CrSelectCommand(gdb.Command):
    """Select the coroutine that GDB is viewing
    Usage: cr select COROUTINE
    COROUTINE is either a coroutine ID or coroutine name."""

    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        gdb.Command.__init__(self, "cr select", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE)

    def invoke(self, arg: str, from_tty: bool) -> None:
        argv = gdb.string_to_argv(arg)
        if len(argv) != 1:
            raise gdb.GdbError("Usage: cr select COROUTINE")
        cr = self._find(argv[0])
        cr.select()
        gdb.execute("frame")

    def _find(self, name: str) -> CrCoroutine:
        if name.isnumeric():
            cid = int(name)
            if (
                self.cr_globals.is_valid_cid(cid)
                and self.cr_globals.coroutines[cid - 1].state != self.cr_globals.CR_NONE
            ):
                return self.cr_globals.coroutines[cid - 1]
        crs: list[CrCoroutine] = []
        for cr in self.cr_globals.coroutines:
            if cr.state != self.cr_globals.CR_NONE and cr.name == name:
                crs += [cr]
        match len(crs):
            case 0:
                raise gdb.GdbError(f"No such coroutine: {repr(name)}")
            case 1:
                return crs[0]
            case _:
                raise gdb.GdbError(f"Ambiguous name, must use Id: {repr(name)}")


# Wire it all in ###############################################################

# If this script is `source`d again into the same namespace (GDB evaluates
# sourced Python in a persistent `__main__` namespace), keep the existing
# state rather than resetting it to None.  That way the `if cr_globals:` check
# at the bottom of the file re-initializes in place, instead of leaving the old
# instance's `cont` handler and breakpoint registered but orphaned.
cr_globals: CrGlobals | None = globals().get("cr_globals")


def cr_initialize() -> None:
    """Create (or re-create) the libcr state and register the `cr` commands.

    If state already exists, each coroutine's saved `_cont_env` and the set of
    threads already seen by the `cont` handler are carried over to the fresh
    `CrGlobals`.
    """
    global cr_globals
    saved_envs: list[gdb_JmpBuf | None] = []
    saved_threads: set[gdb.InferiorThread] | None = None
    if cr_globals:
        old = cr_globals
        saved_envs = [cr._cont_env for cr in old.coroutines]
        saved_threads = old._known_threads
        # Tear the old instance down *before* building the new one.  The new
        # CrBreakpoint registers its unwinder with replace=True, which removes
        # the old (same-named) unwinder.  The old delete() would then fail in
        # `frame_unwinders.remove()` with ValueError.
        old.delete()
        cr_globals = None
    new = CrGlobals()
    for cr, env in zip(new.coroutines, saved_envs):
        cr._cont_env = env
    if saved_threads is not None:
        # Without this, the first continue after re-initialization would look
        # like a thread-creation event and be ignored, so the carried-over
        # `_cont_env`s would survive the continue.
        new._known_threads = saved_threads
    cr_globals = new
    CrCommand(cr_globals)
    CrListCommand(cr_globals)
    CrSelectCommand(cr_globals)


def cr_on_new_objfile(event: gdb.Event) -> None:
    """`new_objfile` handler: set up libcr integration once libcr is loaded.

    As soon as any loaded objfile defines `cr_gdb_readjmp`, initialize and
    disconnect this handler so that it fires only once.
    """
    libcr_loaded = any(
        objfile.lookup_global_symbol("cr_gdb_readjmp") for objfile in gdb.objfiles()
    )
    if not libcr_loaded:
        return
    print("Initializing libcr integration...")
    cr_initialize()
    gdb.events.new_objfile.disconnect(cr_on_new_objfile)


# Hook into GDB: with no libcr state yet, wait for an objfile that defines
# libcr's GDB hooks to be loaded; if state already exists, rebuild it in place
# (cr_initialize() carries the per-coroutine state over).
if not cr_globals:
    gdb.events.new_objfile.connect(cr_on_new_objfile)
else:
    cr_initialize()