# lib9p/protogen/c_format.py - Generate C pretty-print functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later


import idl

from . import c9util, cutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".

# pylint: disable=unused-variable
# Only the generator entry point is exported via `from ... import *`.
__all__ = [
    "gen_c_format",
]


def bf_numname(typ: idl.Bitfield, num: idl.BitNum, base: str) -> str:
    """Return the C identifier for `base` inside bit-number `num` of `typ`.

    The prefix "<typname>_<numname>_" is upper-cased, combined with `base`
    via `c9util.add_prefix`, and the result is passed through
    `c9util.Ident`.  Callers in this file pass either "MASK" or one of the
    bit-number's value names as `base`.
    """
    return c9util.Ident(
        c9util.add_prefix(f"{typ.typname}_{num.numname}_".upper(), base)
    )


def ext_printf(line: str) -> str:
    """Wrap one generated C statement in pragmas that silence GCC format checks.

    `line` must be a complete, tab-indented chunk of C ending in a newline
    (it may span several physical lines).  The returned text pushes the
    diagnostic state, ignores -Wformat and -Wformat-extra-args, emits
    `line`, and pops the diagnostic state again.
    """
    assert line.startswith("\t")
    assert line.endswith("\n")
    # Non-standard conversion specifiers (like the %q used by callers in
    # this file) trip -Wformat and -Wformat-extra-args:
    # https://gcc.gnu.org/bugzilla/show_bug.cgi?id=47781
    suppressed = ("-Wformat", "-Wformat-extra-args")
    pieces = ["#pragma GCC diagnostic push\n"]
    pieces.extend(f'#pragma GCC diagnostic ignored "{flag}"\n' for flag in suppressed)
    pieces.append(line)
    pieces.append("#pragma GCC diagnostic pop\n")
    return "".join(pieces)


def gen_c_format(versions: set[str], typs: list[idl.UserType]) -> str:
    ret = """
/* *_format *******************************************************************/
"""
    for typ in typs:
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"static void {c9util.basename(typ)}_format({c9util.typename(typ)} *self, struct fmt_state *state) {{\n"
        match typ:
            case idl.Number():
                if typ.vals:
                    ret += "\tswitch (*self) {\n"
                    for name in typ.vals:
                        ret += f"\tcase {c9util.Ident(c9util.add_prefix(f'{typ.typname}_'.upper(), name))}:\n"
                        ret += f'\t\tfmt_state_puts(state, "{name}");\n'
                        ret += "\t\tbreak;\n"
                    ret += "\tdefault:\n"
                    ret += f'\t\tfmt_state_printf(state, "%"PRIu{typ.static_size*8}, *self);\n'
                    ret += "\t}\n"
                else:
                    ret += f'\t\tfmt_state_printf(state, "%"PRIu{typ.static_size*8}, *self);\n'
            case idl.Bitfield():
                val = "*self"
                if typ.typname == "dm":  # SPECIAL (pretty file permissions)
                    val = f"(*self & ~(({c9util.typename(typ)})0777))"
                ret += "\tbool empty = true;\n"
                ret += "\tfmt_state_putchar(state, '(');\n"
                nums: set[str] = set()

                for bit in reversed(typ.bits):
                    match bit.cat:
                        case "UNUSED" | "USED" | "RESERVED":
                            if bit.cat == "UNUSED":
                                bitname = f"1<<{bit.num}"
                            else:
                                bitname = bit.bitname
                            ret += f"\tif ({val} & (UINT{typ.static_size*8}_C(1)<<{bit.num})) {{\n"
                            ret += "\t\tif (!empty)\n"
                            ret += "\t\t\tfmt_state_putchar(state, '|');\n"
                            ret += f'\t\tfmt_state_puts(state, "{bitname}");\n'
                            ret += "\t\tempty = false;\n"
                            ret += "\t}\n"
                        case idl.BitNum():
                            if bit.cat.numname in nums:
                                continue
                            ret += f"\tswitch ({val} & {bf_numname(typ, bit.cat, 'MASK')}) {{\n"
                            for name in bit.cat.vals:
                                ret += f"\tcase {bf_numname(typ, bit.cat, name)}:\n"
                                bitname = c9util.add_prefix(
                                    f"{bit.cat.numname}_".upper(), name
                                )
                                ret += "\t\tif (!empty)\n"
                                ret += "\t\t\tfmt_state_putchar(state, '|');\n"
                                ret += f'\t\tfmt_state_puts(state, "{bitname}");\n'
                                ret += "\t\tempty = false;\n"
                                ret += "\t\tbreak;\n"
                            ret += "\tdefault:\n"
                            ret += "\t\tif (!empty)\n"
                            ret += "\t\t\tfmt_state_putchar(state, '|');\n"
                            ret += f'\t\tfmt_state_printf(state, "%"PRIu{typ.static_size*8}, {val} & {bf_numname(typ, bit.cat, 'MASK')});\n'
                            ret += "\t\tempty = false;\n"
                            ret += "\t}\n"
                            nums.add(bit.cat.numname)
                if typ.typname == "dm":  # SPECIAL (pretty file permissions)
                    ret += "\tif (!empty)\n"
                    ret += "\t\tfmt_state_putchar(state, '|');\n"
                    ret += f'\tfmt_state_printf(state, "%#04"PRIo{typ.static_size*8}, *self & 0777);\n'
                else:
                    ret += "\tif (empty)\n"
                    ret += "\t\tfmt_state_putchar(state, '0');\n"
                ret += "\tfmt_state_putchar(state, ')');\n"
            case idl.Struct(typname="s"):  # SPECIAL(string)
                ret += ext_printf(
                    '\tfmt_state_printf(state, "%.*q", self->len, self->utf8);\n'
                )
            case idl.Struct():  # and idl.Message():
                if isinstance(typ, idl.Message):
                    ret += f'\tfmt_state_puts(state, "{typ.typname} {{");\n'
                else:
                    ret += "\tfmt_state_putchar(state, '{');\n"
                for member in typ.members:
                    if member.val:
                        continue
                    ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
                    if member.cnt:
                        if isinstance(member.cnt, int):
                            cnt_str = str(member.cnt)
                            cnt_typ = "size_t"
                        else:
                            cnt_str = f"self->{member.cnt.membname}"
                            cnt_typ = c9util.typename(member.cnt.typ)
                        if member.typ.static_size == 1:  # SPECIAL (data)
                            ret += f"\tif (is_valid_utf8_without_nul((uint8_t *)self->{member.membname}, (size_t){cnt_str})) {{\n"
                            ret += ext_printf(
                                f'\t\tfmt_state_printf(state, " {member.membname}=%.*q%s",\n'
                                f"\t\t\t(int)({cnt_str} < 50 ? {cnt_str} : 50),\n"
                                f"\t\t\t(char *)self->{member.membname},\n"
                                f'\t\t\t{cnt_str} < 50 ? "" : "...");\n'
                            )
                            ret += "\t} else {\n"
                            ret += f'\t\tfmt_state_puts(state, " {member.membname}=<bytedata>");\n'
                            ret += "\t}\n"
                            continue
                        ret += f'\tfmt_state_puts(state, " {member.membname}=[");\n'
                        ret += f"\tfor ({cnt_typ} i = 0; i < {cnt_str}; i++) {{\n"
                        ret += "\t\tif (i)\n"
                        ret += '\t\t\tfmt_state_puts(state, ", ");\n'
                        if isinstance(member.typ, idl.Primitive):
                            ret += f'\t\tfmt_state_printf(state, "%"PRIu{member.typ.static_size*8}, self->{member.membname}[i]);\n'
                        else:
                            ret += f"\t\t{c9util.basename(member.typ)}_format(&self->{member.membname}[i], state);\n"
                        ret += "\t}\n"
                        ret += '\tfmt_state_puts(state, " ]");\n'
                    else:
                        ret += f'\tfmt_state_puts(state, " {member.membname}=");\n'
                        if isinstance(member.typ, idl.Primitive):
                            ret += f'\tfmt_state_printf(state, "%"PRIu{member.typ.static_size*8}, self->{member.membname});\n'
                        else:
                            ret += f"\t{c9util.basename(member.typ)}_format(&self->{member.membname}, state);\n"
                ret += cutil.ifdef_pop(1)
                ret += '\tfmt_state_puts(state, " }");\n'
        ret += "}\n"
    ret += cutil.ifdef_pop(0)

    return ret