# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later


import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# Public API of this module: only the generator entry point is exported;
# everything else here is an implementation detail.
# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]


def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
    """Tell whether the generated validator must keep `member`'s decoded value.

    A value has to be stored in a C local when it is needed after it has
    been consumed from the wire:

    * the member carries a `,max=` or `,val=` constraint that is checked
      after all members have been validated, or
    * some other member of `typ` uses it as its element count (`cnt`).
    """
    if member.max or member.val:
        return True
    for sibling in typ.members:
        if sibling.cnt == member:
            return True
    return False


def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    ret = """
/* validate_* *****************************************************************/

LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
\t\t/* If needed-net-size overflowed uint32_t, then
\t\t * there's no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\tif (ctx->net_offset > ctx->net_size)
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
\t\t/* If needed-host-size overflowed size_t, then there's
\t\t * no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
                                         size_t cnt,
                                         _validate_fn_t item_fn, size_t item_host_size) {
\tfor (size_t i = 0; i < cnt; i++)
\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
\t\t\treturn true;
\treturn false;
}

LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""

    for typ in idlutil.topo_sorted(typs):
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = (
            c9util.arg_unused
            if (isinstance(typ, idl.Struct) and not typ.members)
            else c9util.arg_used
        )
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"

        match typ:
            case idl.Number():
                ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
            case idl.Bitfield():
                ret += f"\t if (validate_{typ.static_size}(ctx))\n"
                ret += "\t\treturn true;\n"
                ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
                if typ.static_size == 1:
                    ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
                else:
                    ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
                ret += "\tif (val & ~mask)\n"
                ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
                ret += "\treturn false;\n"
            case idl.Struct():  # and idl.Message()
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare value variables
                for member in typ.members:
                    if should_save_value(typ, member):
                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
                        ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
                ret += cutil.ifdef_pop(1)

                # Pass 2 - declare offset variables
                mark_offset: set[str] = set()
                for member in typ.members:
                    for tok in [*member.max.tokens, *member.val.tokens]:
                        if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
                            if tok.symname[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
                            mark_offset.add(tok.symname[1:])

                # Pass 3 - main pass
                ret += "\treturn false\n"
                for member in typ.members:
                    ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
                    ret += "\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c9util.ver_cond(member.in_versions) + " && "
                    if member.cnt is not None:
                        if member.typ.static_size == 1:  # SPECIAL (zerocopy)
                            ret += f"_validate_size_net(ctx, {member.cnt.membname})"
                        else:
                            ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
                        if typ.typname == "s":  # SPECIAL (string)
                            ret += '\n\t    || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
                    else:
                        if should_save_value(typ, member):
                            ret += "("
                        if member.membname in mark_offset:
                            ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
                        ret += f"validate_{member.typ.typname}(ctx)"
                        if member.membname in mark_offset:
                            ret += "; })"
                        if should_save_value(typ, member):
                            nbytes = member.static_size
                            assert nbytes
                            if nbytes == 1:
                                ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
                            else:
                                ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"

                # Pass 4 - validate ,max= and ,val= constraints
                for member in typ.members:

                    def lookup_sym(sym: str) -> str:
                        match sym:
                            case "end":
                                return "ctx->net_offset"
                            case _:
                                assert sym.startswith("&")
                                return f"_{sym[1:]}_offset"

                    if member.max:
                        assert member.static_size
                        nbits = member.static_size * 8
                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
                    if member.val:
                        assert member.static_size
                        nbits = member.static_size * 8
                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'

                ret += cutil.ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret