# lib9p/idl/__init__.py - A parser for .9p specification files.
#
# Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import enum
import os.path
import re
from typing import Callable, Literal, TypeAlias, TypeVar, cast

# Public names of this module.  The `*[...]` groupings below are purely
# visual: they are unpacked in place, so the result is one flat list of
# strings.
__all__ = [
    # entrypoint
    "Parser",
    # types
    "Type",
    "Primitive",
    "Number",
    *["Bitfield", "BitfieldVal"],
    *["Struct", "StructMember", "Expr", "ExprOp", "ExprSym", "ExprLit"],
    "Message",
]

# The syntax that this parses is described in `./0000-README.md`.

# Types ########################################################################


class Primitive(enum.Enum):
    """The built-in unsigned integer types.

    The value of each member is its size in bytes.
    """

    u8 = 1
    u16 = 2
    u32 = 4
    u64 = 8

    @property
    def in_versions(self) -> set[str]:
        """A new empty set on every access.

        Present so that a primitive exposes the same `in_versions`
        attribute as the other type classes in this module.
        """
        return set()

    @property
    def name(self) -> str:
        """The size in bytes as a string ("1", "2", "4", or "8").

        This shadows the usual `Enum.name` ("u8", "u16", ...).
        """
        return f"{self.value}"

    @property
    def static_size(self) -> int:
        """The size of the primitive, in bytes."""
        nbytes: int = self.value
        return nbytes


class Number:
    """A named numeric type that is stored as a single `Primitive`."""

    name: str
    in_versions: set[str]

    prim: Primitive

    def __init__(self) -> None:
        # Only `in_versions` gets a default; `name` and `prim` must be
        # assigned by whoever creates the object.
        self.in_versions = set()

    @property
    def static_size(self) -> int:
        """The size in bytes, which is that of the backing primitive."""
        backing = self.prim
        return backing.static_size


class BitfieldVal:
    """One named value (a single bit or an alias) inside a `Bitfield`."""

    name: str
    in_versions: set[str]

    # The value as source text rather than as a number; for a single
    # bit the parser stores a string of the form "1<<N".
    val: str

    def __init__(self) -> None:
        # Only `in_versions` gets a default; `name` and `val` must be
        # assigned by whoever creates the object.
        self.in_versions = set()


class Bitfield:
    """A named set of bit flags that is stored as a single `Primitive`."""

    name: str
    in_versions: set[str]

    prim: Primitive

    # `bits` is indexed by bit number and holds the name of each bit;
    # an empty string means that the bit has not been assigned a name.
    bits: list[str]  # bitnames
    names: dict[str, BitfieldVal]  # bits *and* aliases

    def __init__(self) -> None:
        # `name`, `prim`, and `bits` must be assigned by whoever
        # creates the object.
        self.names = {}
        self.in_versions = set()

    @property
    def static_size(self) -> int:
        """The size in bytes, which is that of the backing primitive."""
        backing = self.prim
        return backing.static_size

    def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
        """Return whether the given bit is valid in the given protocol
        version.

        `bit` is either a bit number (an index into `self.bits`) or a
        bit name.  An unnamed bit, or a bit whose name starts with
        "_", is never valid.  If `ver` is given (and non-empty), then
        the bit must also list that version in its `in_versions`.

        """
        if isinstance(bit, int):
            bitname = self.bits[bit]
        else:
            bitname = bit
        # Aliases (which are in `self.names` but not in `self.bits`)
        # are not accepted here.
        assert bitname in self.bits
        if bitname == "" or bitname.startswith("_"):
            return False
        if not ver:
            return True
        return ver in self.names[bitname].in_versions


class ExprLit:
    """An expression token: an integer literal."""

    val: int

    def __init__(self, val: int) -> None:
        # The literal's numeric value.
        self.val = val


class ExprSym:
    """An expression token: a symbol, stored by name."""

    name: str

    def __init__(self, name: str) -> None:
        # The symbol text exactly as it appeared in the expression.
        self.name = name


class ExprOp:
    """An expression token: one of the operators "-" or "+"."""

    op: Literal["-", "+"]

    def __init__(self, op: Literal["-", "+"]) -> None:
        # The operator character itself.
        self.op = op


class Expr:
    """A flat sequence of expression tokens.

    An `Expr` with no tokens is falsy; it is what a struct member gets
    when no expression was written for it.
    """

    tokens: list[ExprLit | ExprSym | ExprOp]

    def __init__(self) -> None:
        self.tokens = []

    def __bool__(self) -> bool:
        # True exactly when at least one token is present.
        return bool(self.tokens)


class StructMember:
    """One member of a `Struct` or `Message`.

    There is no `__init__`; apart from `cnt`, every attribute must be
    assigned by whoever creates the object.
    """

    # from left-to-right when parsing
    cnt: str | None = None  # name of the member holding the repeat count
    name: str
    typ: "Type"
    max: Expr
    val: Expr

    in_versions: set[str]

    @property
    def static_size(self) -> int | None:
        """The size in bytes, or None if it is not fixed.

        A member with a repeat count never has a fixed size;
        otherwise the size is whatever the member's type reports.
        """
        return None if self.cnt else self.typ.static_size


class Struct:
    """A named, ordered collection of `StructMember`s."""

    name: str
    in_versions: set[str]

    members: list[StructMember]

    def __init__(self) -> None:
        # Only `in_versions` gets a default; `name` and `members` must
        # be assigned by whoever creates the object.
        self.in_versions = set()

    @property
    def static_size(self) -> int | None:
        """The total size in bytes, or None if any member's size is
        not fixed.

        """
        total = 0
        for field in self.members:
            field_size = field.static_size
            if field_size is None:
                # One variable-size member makes the whole struct
                # variable-size.
                return None
            total += field_size
        return total


class Message(Struct):
    """A `Struct` that is a protocol message."""

    @property
    def msgid(self) -> int:
        """The message's numeric ID.

        The ID is read from the second member, which must be a 1-byte
        member named "typ" whose `val` expression is a single integer
        literal.  Anything else trips an assertion.
        """
        assert len(self.members) >= 3
        typ_member = self.members[1]
        assert typ_member.name == "typ"
        assert typ_member.static_size == 1
        assert typ_member.val
        assert len(typ_member.val.tokens) == 1
        literal = typ_member.val.tokens[0]
        assert isinstance(literal, ExprLit)
        return literal.val


# Any type that a struct member can have: a built-in primitive or one
# of the kinds of type that a .9p file can declare.
Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message
# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
# NOTE(review): the `type` statement (PEP 695) is available starting
# with Python 3.12; confirm whether 3.13 is really the intended floor.

# Constrained type variable over the declarable (non-primitive) kinds
# of type; used by `get_type()` inside `parse_file()` so that the
# return type follows the class that was asked for.
T = TypeVar("T", Number, Bitfield, Struct, Message)

# Parse ########################################################################

# Regular-expression fragments.  Each one is wrapped in its own
# non-capturing group so that it can be dropped into a larger pattern
# without changing how that pattern groups.

re_priname = r"(?:1|2|4|8)"  # primitive names
re_symname = r"(?:[a-zA-Z_][a-zA-Z_0-9]*)"  # "symbol" names; most *.9p-defined names
re_impname = rf"(?:\*|{re_symname})"  # names we can import
re_msgname = r"(?:[TR][a-zA-Z_0-9]*)"  # names a message can be

re_memtype = rf"(?:{re_symname}|{re_priname})"  # typenames that a struct member can be

# One or more of: an operator, a decimal literal, or a symbol that is
# optionally prefixed with "&".
re_expr = rf"(?:(?:-|\+|[0-9]+|&?{re_symname})+)"

# A bitfield entry is either `BITNUMBER=name` or `name=VALUE`.
re_bitspec_bit = rf"(?P<bit>[0-9]+)\s*=\s*(?P<name>{re_symname})"
re_bitspec_alias = rf"(?P<name>{re_symname})\s*=\s*(?P<val>\S+)"

# A struct member is `name[typ]`, optionally with `,max=EXPR` and/or
# `,val=EXPR` after the type, and optionally wrapped as
# `cnt*(name[typ])`.
re_memberspec = (
    rf"(?:(?P<cnt>{re_symname})\*\()?"
    rf"(?P<name>{re_symname})"
    rf"\[(?P<typ>{re_memtype})"
    rf"(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*"
    r"\]\)?"
)


def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
    """Parse one bitfield entry and record it in `bf`.

    `spec` (surrounding whitespace is ignored) is either
    `BITNUMBER=name`, which names a single bit, or `name=VALUE`, which
    defines an alias whose value text is stored as-is.  Either way the
    new entry is marked as belonging to version `ver` and is added to
    `bf.names`; a single bit is additionally recorded in `bf.bits`.

    Raises SyntaxError if `spec` has neither form, and ValueError if
    the bit number is out of range, the bit is already named, or the
    name is already in use.
    """
    spec = spec.strip()

    entry = BitfieldVal()
    entry.in_versions.add(ver)

    bit_match = re.fullmatch(re_bitspec_bit, spec)
    if bit_match is not None:
        bitnum = int(bit_match.group("bit"))
        entry.name = bit_match.group("name")
        entry.val = f"1<<{bitnum}"

        if not (0 <= bitnum < len(bf.bits)):
            raise ValueError(f"{bf.name}: bit {bitnum} is out-of-bounds")
        if bf.bits[bitnum]:
            raise ValueError(f"{bf.name}: bit {bitnum} already assigned")
        bf.bits[bitnum] = entry.name
    else:
        alias_match = re.fullmatch(re_bitspec_alias, spec)
        if alias_match is None:
            raise SyntaxError(f"invalid bitfield spec {repr(spec)}")
        entry.name = alias_match.group("name")
        entry.val = alias_match.group("val")

    # Bits and aliases share a single namespace.
    if entry.name in bf.names:
        raise ValueError(f"{bf.name}: name {entry.name} already assigned")
    bf.names[entry.name] = entry


def parse_expr(expr: str) -> Expr:
    """Split an expression string into an `Expr` token list.

    `expr` must already match `re_expr` in full (this is asserted, not
    validated; callers are expected to have matched it beforehand).
    Each "-" or "+" becomes an `ExprOp`, each run of decimal digits
    becomes an `ExprLit`, and anything else becomes an `ExprSym`.
    """
    assert re.fullmatch(re_expr, expr)
    ret = Expr()
    for tok in re.split("([-+])", expr):
        if not tok:
            # re.split() yields an empty string on either side of a
            # leading, trailing, or doubled operator (for example,
            # "-1" splits to ["", "-", "1"]).  Those empty pieces are
            # not tokens; without this check they would be emitted as
            # symbols with an empty name.
            continue
        if tok in ("-", "+"):
            # The cast() is for mypy's benefit: it keeps treating
            # `tok` as a plain `str` even after the comparison above.
            ret.tokens.append(ExprOp(cast(Literal["-", "+"], tok)))
        elif re.fullmatch("[0-9]+", tok):
            ret.tokens.append(ExprLit(int(tok)))
        else:
            ret.tokens.append(ExprSym(tok))
    return ret


def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None:
    """Parse whitespace-separated member specs and append them to
    `struct.members`.

    Each spec has the form `name[typ]`, optionally with `,max=EXPR`
    and/or `,val=EXPR` after the type, and optionally wrapped as
    `cnt*(name[typ])` to mark it as repeated `cnt` times.  Type names
    are looked up in `env`, and every new member is marked as
    belonging to version `ver`.

    Raises SyntaxError for a malformed spec, NameError for an unknown
    type name, and ValueError for a duplicate member name, a count
    that is not the immediately preceding primitive-typed member, or
    a `max=`/`val=` on a member that is repeated or not
    primitive-typed.
    """
    for spec in specs.split():
        m = re.fullmatch(re_memberspec, spec)
        if not m:
            raise SyntaxError(f"invalid member spec {repr(spec)}")

        member = StructMember()
        member.in_versions = {ver}

        member.name = m.group("name")
        if any(x.name == member.name for x in struct.members):
            raise ValueError(f"duplicate member name {repr(member.name)}")

        # Look the group up by name: numeric group 2 is the member
        # *name*, so using it here would report the wrong identifier.
        typname = m.group("typ")
        if typname not in env:
            raise NameError(f"Unknown type {repr(typname)}")
        member.typ = env[typname]

        if cnt := m.group("cnt"):
            # The count must be the member that comes directly before
            # the repeated one.
            if len(struct.members) == 0 or struct.members[-1].name != cnt:
                raise ValueError(f"list count must be previous item: {repr(cnt)}")
            if not isinstance(struct.members[-1].typ, Primitive):
                raise ValueError(f"list count must be an integer type: {repr(cnt)}")
            member.cnt = cnt

        if maxstr := m.group("max"):
            if (not isinstance(member.typ, Primitive)) or member.cnt:
                raise ValueError("',max=' may only be specified on a non-repeated atom")
            member.max = parse_expr(maxstr)
        else:
            # No expression given; an empty Expr is falsy.
            member.max = Expr()

        if valstr := m.group("val"):
            if (not isinstance(member.typ, Primitive)) or member.cnt:
                raise ValueError("',val=' may only be specified on a non-repeated atom")
            member.val = parse_expr(valstr)
        else:
            member.val = Expr()

        struct.members += [member]


def re_string(grpname: str) -> str:
    """Return a regex fragment that matches a double-quoted string.

    The text between the quotes (which may be empty, but may not
    itself contain a double quote) is captured in a group named
    `grpname`.
    """
    return '"(?P<' + grpname + '>[^"]*)"'


# One pattern per kind of line; each is meant to be matched against a
# whole line.

re_line_version = rf"version\s+{re_string('version')}"
re_line_import = (
    rf"from\s+(?P<file>\S+)\s+import\s+"
    rf"(?P<syms>{re_impname}(?:\s*,\s*{re_impname})*)"
)
re_line_num = rf"num\s+(?P<name>{re_symname})\s*=\s*(?P<prim>{re_priname})"
re_line_bitfield = rf"bitfield\s+(?P<name>{re_symname})\s*=\s*(?P<prim>{re_priname})"
# `bitfield NAME += "..."` adds an entry to an existing bitfield.
re_line_bitfield_ = rf"bitfield\s+(?P<name>{re_symname})\s*\+=\s*{re_string('member')}"
# For structs and messages, the `op` group is "=" or "+=".
re_line_struct = (
    rf"struct\s+(?P<name>{re_symname})\s*(?P<op>\+?=)\s*{re_string('members')}"
)
re_line_msg = rf"msg\s+(?P<name>{re_msgname})\s*(?P<op>\+?=)\s*{re_string('members')}"
re_line_cont = rf"\s+{re_string('specs')}"  # could be bitfield/struct/msg


def parse_file(
    filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
) -> tuple[str, list[Type]]:
    """Parse a single .9p file.

    `get_include` is called with the filename written in each
    `from FILE import ...` line, and must return the `(version, types)`
    pair for that file.

    Returns the file's version string together with every
    non-primitive type known at the end of the file, which includes
    the types that were imported.

    Raises SyntaxError for a malformed line or a missing/duplicate
    version line, NameError for a reference to an unknown or
    wrong-kind type, and ValueError for the semantic errors detected
    by the helpers or by the final expression-symbol check.
    """
    version: str | None = None
    # Name -> type lookup.  It starts out holding only the primitives,
    # which are named by their size in bytes.
    env: dict[str, Type] = {
        "1": Primitive.u8,
        "2": Primitive.u16,
        "4": Primitive.u32,
        "8": Primitive.u64,
    }

    def get_type(name: str, tc: type[T]) -> T:
        """Look up `name` in `env`, requiring it to be exactly a `tc`."""
        nonlocal env
        if name not in env:
            raise NameError(f"Unknown type {repr(name)}")
        ret = env[name]
        # isinstance() alone would accept a Message where a Struct was
        # asked for (Message subclasses Struct), so the class name is
        # compared as well.
        if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
            raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
        return ret

    with open(filename, "r") as fh:
        # The type most recently defined or extended; this is what a
        # continuation line adds to.
        prev: Type | None = None
        for line in fh:
            # Drop everything from the first "#" onward (even if it is
            # inside a quoted string), then trailing whitespace.
            # Leading whitespace is kept because it is what marks a
            # continuation line.
            line = line.split("#", 1)[0].rstrip()
            if not line:
                continue
            if m := re.fullmatch(re_line_version, line):
                if version:
                    raise SyntaxError("must have exactly 1 version line")
                version = m.group("version")
                continue
            # Every other kind of line needs the version to already be
            # known.
            if not version:
                raise SyntaxError("must have exactly 1 version line")

            if m := re.fullmatch(re_line_import, line):
                other_version, other_typs = get_include(m.group("file"))
                for symname in m.group("syms").split(sep=","):
                    symname = symname.strip()
                    for typ in other_typs:
                        if typ.name == symname or symname == "*":
                            # The imported object itself is modified
                            # (not a copy): it is marked as also
                            # existing in this file's version.  Its
                            # bitfield values / struct members are
                            # marked only if they exist in the other
                            # file's own version.
                            match typ:
                                case Primitive():
                                    pass
                                case Number():
                                    typ.in_versions.add(version)
                                case Bitfield():
                                    typ.in_versions.add(version)
                                    for val in typ.names.values():
                                        if other_version in val.in_versions:
                                            val.in_versions.add(version)
                                case Struct():  # and Message()
                                    typ.in_versions.add(version)
                                    for member in typ.members:
                                        if other_version in member.in_versions:
                                            member.in_versions.add(version)
                            env[typ.name] = typ
            elif m := re.fullmatch(re_line_num, line):
                num = Number()
                num.name = m.group("name")
                num.in_versions.add(version)

                prim = env[m.group("prim")]
                assert isinstance(prim, Primitive)
                num.prim = prim

                env[num.name] = num
                prev = num
            elif m := re.fullmatch(re_line_bitfield, line):
                bf = Bitfield()
                bf.name = m.group("name")
                bf.in_versions.add(version)

                prim = env[m.group("prim")]
                assert isinstance(prim, Primitive)
                bf.prim = prim

                # One (initially unnamed) slot per bit of the backing
                # primitive.
                bf.bits = (prim.static_size * 8) * [""]

                env[bf.name] = bf
                prev = bf
            elif m := re.fullmatch(re_line_bitfield_, line):
                bf = get_type(m.group("name"), Bitfield)
                parse_bitspec(version, bf, m.group("member"))

                prev = bf
            elif m := re.fullmatch(re_line_struct, line):
                match m.group("op"):
                    case "=":
                        # Define a new struct.
                        struct = Struct()
                        struct.name = m.group("name")
                        struct.in_versions.add(version)
                        struct.members = []
                        parse_members(version, env, struct, m.group("members"))

                        env[struct.name] = struct
                        prev = struct
                    case "+=":
                        # Append members to an existing struct.
                        struct = get_type(m.group("name"), Struct)
                        parse_members(version, env, struct, m.group("members"))

                        prev = struct
            elif m := re.fullmatch(re_line_msg, line):
                match m.group("op"):
                    case "=":
                        # Define a new message.
                        msg = Message()
                        msg.name = m.group("name")
                        msg.in_versions.add(version)
                        msg.members = []
                        parse_members(version, env, msg, m.group("members"))

                        env[msg.name] = msg
                        prev = msg
                    case "+=":
                        # Append members to an existing message.
                        msg = get_type(m.group("name"), Message)
                        parse_members(version, env, msg, m.group("members"))

                        prev = msg
            elif m := re.fullmatch(re_line_cont, line):
                # An indented quoted string adds more entries to
                # whatever `prev` currently is.
                match prev:
                    case Bitfield():
                        parse_bitspec(version, prev, m.group("specs"))
                    case Struct():  # and Message()
                        parse_members(version, env, prev, m.group("specs"))
                    case _:
                        raise SyntaxError(
                            "continuation line must come after a bitfield, struct, or msg line"
                        )
            else:
                raise SyntaxError(f"invalid line {repr(line)}")
    if not version:
        raise SyntaxError("must have exactly 1 version line")

    typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]

    # Check the symbols used in `max=`/`val=` expressions: the only
    # ones accepted are "end" and "&NAME", where NAME is a member of
    # the same struct.
    for typ in [typ for typ in typs if isinstance(typ, Struct)]:
        valid_syms = ["end", *["&" + m.name for m in typ.members]]
        for member in typ.members:
            for tok in [*member.max.tokens, *member.val.tokens]:
                if isinstance(tok, ExprSym) and tok.name not in valid_syms:
                    raise ValueError(
                        f"{typ.name}.{member.name}: invalid sym: {tok.name}"
                    )

    return version, typs


# Filesystem ###################################################################


class Parser:
    """Parse .9p files from the filesystem, remembering the results.

    Each file is parsed at most once per `Parser`; files that are
    named in `from FILE import ...` lines are parsed through the same
    `Parser`, so they end up in its cache as well.
    """

    # Maps a normalized filename to the `(version, types)` pair that
    # was parsed from that file.
    cache: dict[str, tuple[str, list[Type]]]

    def __init__(self) -> None:
        # Create the dict here rather than in the class body: a dict
        # in the class body would be one object shared by every
        # Parser, so that files parsed through one Parser would show
        # up in the results of another.
        self.cache = {}

    def parse_file(self, filename: str) -> tuple[str, list[Type]]:
        """Parse `filename` (or reuse the cached result) and return
        its `(version, types)` pair.

        """
        filename = os.path.normpath(filename)
        if filename not in self.cache:

            def get_include(other_filename: str) -> tuple[str, list[Type]]:
                # Imported filenames are taken relative to the
                # directory that contains the importing file.
                return self.parse_file(os.path.join(filename, "..", other_filename))

            # NOTE(review): the cache entry is stored only after the
            # parse has finished, so two files that import each other
            # would recurse without end.
            self.cache[filename] = parse_file(filename, get_include)
        return self.cache[filename]

    def all(self) -> tuple[set[str], list[Type]]:
        """Return every version and every type parsed so far.

        Raises ValueError if two files declare the same version, or if
        two different type objects have the same name.
        """
        ret_versions: set[str] = set()
        ret_typs: dict[str, Type] = {}
        for version, typs in self.cache.values():
            if version in ret_versions:
                raise ValueError(f"duplicate protocol version {repr(version)}")
            ret_versions.add(version)
            for typ in typs:
                # An imported type is the very same object in both
                # files, so seeing it again is fine; only a different
                # object under the same name is a conflict.
                if ret_typs.setdefault(typ.name, typ) != typ:
                    raise ValueError(f"duplicate type name {repr(typ.name)}")
        return ret_versions, list(ret_typs.values())