# gdb-helpers/libcr.py - GDB helpers for libcr.
#
# Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import contextlib
import typing

import gdb

# GDB helpers ##################################################################


def gdb_argv_to_string(argv: list[str]) -> str:
    """Reverse of gdb.string_to_argv().

    Join `argv` into one string such that
    `gdb.string_to_argv(gdb_argv_to_string(argv)) == argv`.

    gdb.string_to_argv() tokenizes with libiberty's buildargv() rules:
    unquoted ASCII whitespace separates words, single/double quotes group,
    and a backslash escapes the next character (even inside quotes).
    Backslash-escaping every character that is significant to that parser
    is therefore sufficient.  An empty word is written as `''`, because it
    would otherwise disappear.
    """
    # Characters that buildargv() treats specially: whitespace, the escape
    # character, and both quote characters.
    special = " \t\n\v\f\r\\'\""

    def quote(word: str) -> str:
        if not word:
            return "''"
        return "".join("\\" + ch if ch in special else ch for ch in word)

    return " ".join(quote(word) for word in argv)


class GdbJmpBuf:
    """Saved debugger context: the Python/GDB stand-in for C's `jmp_buf`.

    Populated by gdb_setjmp() and consumed by gdb_longjmp().
    """

    # Frame that was selected when the snapshot was taken.
    frame: gdb.Frame
    # Register name -> value text, as scraped from `info registers`.
    registers: dict[str, str]


def gdb_setjmp() -> GdbJmpBuf:
    """Snapshot the selected frame and its registers (GDB's `setjmp()`).

    Returns a GdbJmpBuf that gdb_longjmp() can later restore.  Register
    values are kept as the raw text of the second column of
    `info registers`.
    """
    snapshot = GdbJmpBuf()
    snapshot.frame = gdb.selected_frame()
    snapshot.registers = {}
    dump = gdb.execute("info registers", to_string=True)
    for row in dump.split("\n"):
        # Keep only the first two whitespace-separated columns (name, raw
        # value); any trailing columns are ignored.
        fields = row.split(maxsplit=2)
        if len(fields) >= 2:
            snapshot.registers[fields[0]] = fields[1]
    return snapshot


def gdb_longjmp(buf: GdbJmpBuf) -> None:
    """Restore a snapshot taken by gdb_setjmp() (GDB's `longjmp()`).

    Every saved register is written back with `set $REG = VALUE`, in the
    order it was recorded, and then the saved frame is re-selected.
    """
    for name in buf.registers:
        gdb.execute("set $" + name + " = " + buf.registers[name])
    buf.frame.select()


# Core libcr functionality #####################################################


class CrGlobals:
    """State shared by all of the libcr GDB helpers."""

    # Slot i describes the coroutine whose Id is i + 1.
    coroutines: list["CrCoroutine"]
    # Breakpoints that registered themselves here (see CrYieldBreakpoint).
    breakpoints: list[gdb.Breakpoint]

    def __init__(self) -> None:
        table_len = int(
            gdb.parse_and_eval("sizeof(coroutine_table)/sizeof(coroutine_table[0])")
        )
        self.coroutines = []
        for cid in range(1, table_len + 1):
            self.coroutines.append(CrCoroutine(self, cid))
        self.breakpoints = []

    def is_valid_cid(self, cid: int) -> bool:
        """True iff `cid` names a slot of coroutine_table (Ids are 1-based)."""
        return 0 < cid <= len(self.coroutines)

    @property
    def coroutine_running(self) -> int:
        """The inferior's `coroutine_running` value, read fresh each time."""
        return int(gdb.parse_and_eval("coroutine_running"))

    @property
    def CR_NONE(self) -> gdb.Value:
        """The inferior's `CR_NONE` value, for comparing against states."""
        return gdb.parse_and_eval("CR_NONE")


class CrCoroutine:
    """One slot of the inferior's `coroutine_table`, as seen from GDB."""

    cr_globals: CrGlobals
    # 1-based coroutine Id; the slot is coroutine_table[cid - 1].
    cid: int
    # Register/frame snapshot recorded by CrYieldBreakpoint.stop(); None
    # until one of those breakpoints has fired while this coroutine was the
    # running one.
    env: GdbJmpBuf | None

    def __init__(self, cr_globals: CrGlobals, cid: int) -> None:
        self.cr_globals = cr_globals
        self.cid = cid
        self.env = None

    @property
    def id(self) -> int:
        """Alias for `cid`."""
        return self.cid

    @property
    def state(self) -> gdb.Value:
        """The slot's `state` field, read fresh from the inferior."""
        return gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].state")

    @property
    def name(self) -> str:
        """The slot's NUL-terminated `name` field, decoded as UTF-8.

        Invalid UTF-8 (e.g. a multi-byte character cut off by the size of the
        buffer) is replaced with U+FFFD instead of raising.
        """
        size = int(gdb.parse_and_eval("sizeof(coroutine_table[0].name)"))
        # Fetch the array once and index it locally, rather than having GDB
        # parse and evaluate a separate expression for every byte.
        arr = gdb.parse_and_eval(f"coroutine_table[{self.cid-1}].name")
        # `char` may be signed, in which case bytes >= 0x80 convert to
        # negative ints; mask to recover the raw byte value.
        raw = bytes(int(arr[i]) & 0xFF for i in range(size))
        # Cut at the terminator *before* decoding: whatever follows the NUL
        # is not part of the name and need not be valid UTF-8.
        return raw.split(b"\x00", maxsplit=1)[0].decode("UTF-8", errors="replace")

    @contextlib.contextmanager
    def active(self) -> typing.Iterator[None]:
        """Temporarily switch GDB's registers and selected frame to this coroutine.

        The registers and frame that were current on entry are restored on
        exit, whether the body finishes normally or raises.

        Raises:
            gdb.GdbError: if this coroutine is not the running one and no
                snapshot has been recorded for it yet.
        """
        saved_env = gdb_setjmp()
        # The running coroutine's context is simply the current one.
        cr_env = (
            saved_env
            if self.cid == self.cr_globals.coroutine_running
            else self.env
        )
        if cr_env is None:
            raise gdb.GdbError(f"No saved context for coroutine {self.cid}")
        try:
            # Inside the `try` so that a failure partway through loading
            # cr_env still puts the original registers back.
            gdb_longjmp(cr_env)
            yield
        finally:
            gdb_longjmp(saved_env)


class CrYieldBreakpoint(gdb.Breakpoint):
    """Internal breakpoint that snapshots the running coroutine's context.

    It never actually halts the inferior: stop() stores a gdb_setjmp()
    snapshot in the running coroutine's `env`, if the running Id is valid,
    and returns False.
    """

    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals, function: str) -> None:
        self.cr_globals = cr_globals
        cr_globals.breakpoints.append(self)
        super().__init__(function=function, type=gdb.BP_BREAKPOINT, internal=True)  # type: ignore

    def stop(self) -> bool:
        running = self.cr_globals.coroutine_running
        if self.cr_globals.is_valid_cid(running):
            self.cr_globals.coroutines[running - 1].env = gdb_setjmp()
        # False tells GDB not to stop, so the inferior keeps running.
        return False


# User-facing commands #########################################################


class CrCommand(gdb.Command):
    """Use this command for libcr coroutines."""

    # Root of the `cr` prefix command; the docstring above is the help text
    # GDB shows for it.  "cr list" and "cr apply" hang off this prefix.

    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        super().__init__("cr", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE, True)

    def invoke(self, arg: str, from_tty: bool) -> None:
        # Running the prefix on its own just prints its help.
        gdb.execute("help cr")


class CrListCommand(gdb.Command):
    """List libcr coroutines.
    Usage: cr list"""

    # The docstring above is the help text GDB shows for `help cr list`.

    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        gdb.Command.__init__(self, "cr list", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE)

    def invoke(self, arg: str, from_tty: bool) -> None:
        """Print a table of every coroutine whose state is not CR_NONE.

        The running coroutine is marked with `*`.  The Frame column is the
        first line of `frame` output in that coroutine's context.
        """
        if gdb.string_to_argv(arg):
            raise gdb.GdbError("Usage: cr list")

        # Loop invariants: _pretty_frame only swaps registers temporarily and
        # restores them, so read these once rather than once per row.
        running = self.cr_globals.coroutine_running
        cr_none = self.cr_globals.CR_NONE

        w_marker = 1
        w_id = max(len("Id"), len(str(len(self.cr_globals.coroutines))))
        w_name = max(
            len("Name"), int(gdb.parse_and_eval("sizeof(coroutine_table[0].name)"))
        )
        w_state = len("CR_INITIALIZING")
        print(
            f"  {'Id'.ljust(w_id)}  {'Name'.ljust(w_name)}  {'State'.ljust(w_state)}  Frame"
        )
        for cr in self.cr_globals.coroutines:
            state = cr.state  # read once; used for the filter and the column
            if state == cr_none:
                continue
            v_marker = "*" if cr.id == running else ""
            v_id = str(cr.id)
            v_name = cr.name
            v_state = str(state)
            v_frame = self._pretty_frame(cr, from_tty)
            print(
                f"{v_marker.ljust(w_marker)} {v_id.ljust(w_id)}  {v_name.ljust(w_name)}  {v_state.ljust(w_state)}  {v_frame}"
            )

    def _pretty_frame(self, cr: CrCoroutine, from_tty: bool) -> str:
        """Return `cr`'s `frame` summary line with the leading `#N` removed."""
        # A live coroutine (state != CR_NONE) can have no recorded env yet,
        # presumably e.g. between coroutine_add_with_stack_size() and its own
        # cr_begin().  There is nothing to switch to, so show a placeholder
        # rather than aborting the whole listing.
        if cr.env is None and cr.id != self.cr_globals.coroutine_running:
            return "(no saved context)"
        with cr.active():
            full = gdb.execute("frame", from_tty=from_tty, to_string=True)
        line = full.split("\n", maxsplit=1)[0]
        return line.split(maxsplit=1)[1]


class CrApplyCommand(gdb.Command):
    """Apply a GDB command to libcr coroutines.
    Usage: cr apply COROUTINE... -- COMMAND
    COROUTINE is a space-separated list of coroutine IDs or names."""

    cr_globals: CrGlobals

    def __init__(self, cr_globals: CrGlobals) -> None:
        self.cr_globals = cr_globals
        gdb.Command.__init__(self, "cr apply", gdb.COMMAND_RUNNING, gdb.COMPLETE_NONE)

    def invoke(self, arg: str, from_tty: bool) -> None:
        argv = gdb.string_to_argv(arg)
        if "--" not in argv:
            raise gdb.GdbError("Usage: cr apply COROUTINE... -- COMMAND")
        sep = argv.index("--")
        crs = argv[:sep]
        cmd = argv[sep + 1 :]
        for spec in crs:
            cr = self._find(spec)
            with cr.active():
                gdb.execute(gdb_argv_to_string(cmd), from_tty=from_tty)

    def _find(self, name: str) -> CrCoroutine:
        if name.isnumeric():
            cid = int(name)
            if (
                self.cr_globals.is_valid_cid(cid)
                and self.cr_globals.coroutines[cid - 1].state != self.cr_globals.CR_NONE
            ):
                return self.cr_globals.coroutines[cid - 1]
        crs: list[CrCoroutine] = []
        for cr in self.cr_globals.coroutines:
            if cr.state != self.cr_globals.CR_NONE and cr.name == name:
                crs += [cr]
        match len(crs):
            case 0:
                raise gdb.GdbError(f"No such coroutine: {repr(name)}")
            case 1:
                return crs[0]
            case _:
                raise gdb.GdbError(f"Ambiguous name, must use Id: {repr(name)}")


# Wire it all in ###############################################################


def cr_initialize() -> None:
    """Set up the libcr integration.

    Builds the shared CrGlobals, installs the context-recording breakpoints,
    and registers the `cr` command family.
    """
    cr_globals = CrGlobals()

    # Each of these entry points records the running coroutine's context.
    for function in ("_cr_yield", "cr_begin", "coroutine_add_with_stack_size"):
        CrYieldBreakpoint(cr_globals, function)

    # The prefix command must be registered before its subcommands.
    for command_cls in (CrCommand, CrListCommand, CrApplyCommand):
        command_cls(cr_globals)


def cr_on_new_objfile(event: gdb.Event) -> None:
    """`new_objfile` handler: initialize libcr support once libcr is loaded.

    Loaded objfiles are probed for the static symbol `cr_plat_longjmp`.
    On the first hit, the integration is initialized and this handler
    disconnects itself, so initialization happens at most once.
    """
    for objfile in gdb.objfiles():
        if objfile.lookup_static_symbol("cr_plat_longjmp"):
            break
    else:
        return
    print("Initializing libcr integration...")
    cr_initialize()
    gdb.events.new_objfile.disconnect(cr_on_new_objfile)


gdb.events.new_objfile.connect(cr_on_new_objfile)