#!/usr/bin/env python
# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
#                 defining 9P protocol variants.
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import graphlib
import os.path
import sys
import typing

sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))

import idl

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# Utilities ####################################################################

# Prefix applied to every generated C identifier.
idprefix = "lib9p_"

# Largest values representable in uint32_t and uint64_t respectively.
u32max = 2**32 - 1
u64max = 2**64 - 1


def tab_ljust(s: str, width: int) -> str:
    """Left-justify `s` to `width` columns, treating tabs as 8-column stops.

    Unlike str.ljust(), the current width of `s` is measured after tab
    expansion, so strings containing tabs line up visually.  A string that
    is already at least `width` columns wide is returned unchanged.
    """
    shortfall = width - len(s.expandtabs(tabsize=8))
    return s + " " * shortfall if shortfall > 0 else s


def add_prefix(p: str, s: str) -> str:
    """Prepend `p` to `s`, keeping a leading underscore of `s` at the front.

    e.g. add_prefix("LIB9P_", "_FOO") == "_LIB9P_FOO".
    """
    head, rest = ("_", s[1:]) if s.startswith("_") else ("", s)
    return head + p + rest


def c_macro(full: str) -> str:
    """Format a multi-line C macro with aligned `\\` line continuations.

    Trailing whitespace is stripped from each line.  Every line is padded
    (tab-aware, via tab_ljust) to the width of the widest non-final line.
    The lines are then joined with "  \\" continuations.  The result ends
    in exactly one newline.  `full` must span at least two lines.
    """
    body = full.rstrip()
    assert "\n" in body
    *head, last = (line.rstrip() for line in body.split("\n"))
    width = max(len(line.expandtabs(tabsize=8)) for line in head)
    continued = "".join(tab_ljust(line, width) + "  \\\n" for line in head)
    return (continued + tab_ljust(last, width)).rstrip() + "\n"


def c_ver_enum(ver: str) -> str:
    """Return the C enumerator name for protocol version `ver`.

    Dots in `ver` become underscores, e.g. "9P2000.u" -> "LIB9P_VER_9P2000_u".
    """
    return "{}VER_{}".format(idprefix.upper(), ver.replace(".", "_"))


def c_ver_ifdef(versions: set[str]) -> str:
    """Return a preprocessor condition that is true if any of `versions` is enabled.

    Versions are sorted for stable output; an empty set yields "".
    """
    macros = ["CONFIG_9P_ENABLE_" + v.replace(".", "_") for v in sorted(versions)]
    return " || ".join(macros)


def c_ver_cond(versions: set[str]) -> str:
    """Return a runtime C condition testing whether `ctx` speaks any of `versions`.

    A single version yields a bare `is_ver(ctx, X)`.  Otherwise the
    (sorted) checks are `||`-ed together inside parentheses.
    """
    checks = [f"is_ver(ctx, {v.replace('.', '_')})" for v in sorted(versions)]
    if len(checks) == 1:
        return checks[0]
    return "( " + " || ".join(checks) + " )"


def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str:
    """Return the C spelling of IDL type `typ`.

    `parent` is the struct member being declared, if any.  A 1-byte
    primitive that is a counted member (`parent.cnt` set) is the string
    special case.  It is spelled `[[gnu::nonstring]] char` rather than
    `uint8_t`.

    Raises ValueError if `typ` is not one of the IDL type classes.
    """
    if isinstance(typ, idl.Primitive):
        if typ.value == 1 and parent and parent.cnt:  # SPECIAL (string)
            return "[[gnu::nonstring]] char"
        return f"uint{typ.value*8}_t"
    if isinstance(typ, (idl.Number, idl.Bitfield)):
        return f"{idprefix}{typ.name}_t"
    # Message must be tested before Struct (Messages are handled as
    # Structs elsewhere in this file).
    if isinstance(typ, idl.Message):
        return f"struct {idprefix}msg_{typ.name}"
    if isinstance(typ, idl.Struct):
        return f"struct {idprefix}{typ.name}"
    raise ValueError(f"not a type: {typ.__class__.__name__}")


def c_expr(expr: idl.Expr) -> str:
    """Translate an IDL expression into a space-separated C expression.

    Operators and literals are emitted verbatim.  The symbols `end`,
    `s32_max` and `s64_max` map to `ctx->net_offset`, `INT32_MAX` and
    `INT64_MAX`.  Any other symbol has its first character dropped and
    becomes the local variable `_<rest>_offset`; presumably such symbols
    are always of the form `&member`.  Tokens of any other kind are
    skipped.
    """
    well_known = {
        "end": "ctx->net_offset",
        "s32_max": "INT32_MAX",
        "s64_max": "INT64_MAX",
    }
    pieces: list[str] = []
    for tok in expr.tokens:
        if isinstance(tok, idl.ExprOp):
            pieces.append(tok.op)
        elif isinstance(tok, idl.ExprLit):
            pieces.append(str(tok.val))
        elif isinstance(tok, idl.ExprSym):
            pieces.append(well_known.get(tok.name, f"_{tok.name[1:]}_offset"))
    return " ".join(pieces)


_ifdef_stack: list[str | None] = []


def ifdef_push(n: int, _newval: str) -> str:
    """Set the `#if` condition at nesting level `n` (1-based) to `_newval`.

    Returns the preprocessor text needed to move from the current state to
    the new one.  That is an `#endif` for the condition being replaced (if
    any) and an `#if` for the new one (if any).

    If `_newval` equals the nearest enclosing condition, it is redundant.
    In that case the level is left with no condition of its own.

    Any levels deeper than `n` that are still open are closed first.  This
    keeps the stack exactly `n` deep, so the slot written is level `n`.
    """
    global _ifdef_stack

    # Close anything nested deeper than level n.  This is a no-op when the
    # stack is at most n deep.  Without it, `_ifdef_stack[-1]` below would
    # refer to the wrong level and the deeper `#if`s would never be closed.
    ret = ifdef_pop(n)

    # Grow the stack as needed
    while len(_ifdef_stack) < n:
        _ifdef_stack.append(None)

    # The nearest enclosing (non-None) condition, if any.
    parentval: str | None = None
    for x in _ifdef_stack[:-1]:
        if x is not None:
            parentval = x
    oldval = _ifdef_stack[-1]
    newval: str | None = _newval
    if newval == parentval:
        newval = None

    # Put newval on the stack.
    _ifdef_stack[-1] = newval

    # Build output.
    if newval != oldval:
        if oldval is not None:
            ret += f"#endif /* {oldval} */\n"
        if newval is not None:
            ret += f"#if {newval}\n"
    return ret


def ifdef_pop(n: int) -> str:
    """Close every `#if` nesting level deeper than `n`.

    Levels are popped innermost-first.  An `#endif` is emitted for each
    popped level that had a condition open.  Returns the emitted text.
    """
    global _ifdef_stack
    closing = ""
    while len(_ifdef_stack) > n:
        cond = _ifdef_stack.pop()
        if cond is not None:
            closing += f"#endif /* {cond} */\n"
    return closing


def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]:
    """Order `typs` so that every type comes after the types it uses.

    Number and Bitfield types have no dependencies.  A Struct (this
    includes Messages) depends on the non-Primitive types of its members.
    Other kinds of type in `typs` are not added to the graph.
    """
    sorter: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter()
    for typ in typs:
        if isinstance(typ, (idl.Number, idl.Bitfield)):
            sorter.add(typ)
        elif isinstance(typ, idl.Struct):
            sorter.add(
                typ,
                *(m.typ for m in typ.members if not isinstance(m.typ, idl.Primitive)),
            )
    return sorter.static_order()


# Generate .h ##################################################################


def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#ifndef _LIB9P_9P_H_
\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
#endif

#include <stdint.h> /* for uint{{n}}_t types */
"""

    id2typ: dict[int, idl.Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        id2typ[msg.msgid] = msg

    ret += f"""
/* config *********************************************************************/

#include "config.h"
"""
    for ver in sorted(versions):
        ret += "\n"
        ret += f"#ifndef {c_ver_ifdef({ver})}\n"
        ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
        ret += "#endif\n"

    ret += f"""
/* enum version ***************************************************************/

enum {idprefix}version {{
"""
    fullversions = ["unknown = 0", *sorted(versions)]
    verwidth = max(len(v) for v in fullversions)
    for ver in fullversions:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f"\t{c_ver_enum(ver)},"
        ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
    ret += ifdef_pop(0)
    ret += f"\t{c_ver_enum('NUM')},\n"
    ret += "};\n"

    ret += """
/* enum msg_type **************************************************************/

"""
    ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
    namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message))
    for n in range(0x100):
        if n not in id2typ:
            continue
        msg = id2typ[n]
        ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
        ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
    ret += ifdef_pop(0)
    ret += "};\n"

    ret += """
/* payload types **************************************************************/
"""

    def per_version_comment(
        typ: idl.Type, fn: typing.Callable[[idl.Type, str], str]
    ) -> str:
        lines: dict[str, str] = {}
        for version in sorted(typ.in_versions):
            lines[version] = fn(typ, version)
        if len(set(lines.values())) == 1:
            for _, line in lines.items():
                return f"/* {line} */\n"
            assert False
        else:
            ret = ""
            v_width = max(len(c_ver_enum(v)) for v in typ.in_versions)
            for version, line in lines.items():
                ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n"
            return ret

    for typ in topo_sorted(typs):
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))

        def sum_size(typ: idl.Type, version: str) -> str:
            min_size = typ.min_size(version)
            max_size = typ.max_size(version)
            assert min_size <= max_size and max_size < u64max
            ret = ""
            if min_size == max_size:
                ret += f"size = {min_size:,}"
            else:
                ret += f"min_size = {min_size:,} ; max_size = {max_size:,}"
            if max_size > u32max:
                ret += " (warning: >UINT32_MAX)"
            return ret

        ret += per_version_comment(typ, sum_size)

        match typ:
            case idl.Number():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
                prefix = f"{idprefix.upper()}{typ.name.upper()}_"
                namewidth = max(len(name) for name in typ.vals)
                for name, val in typ.vals.items():
                    ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
            case idl.Bitfield():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
                names = [
                    typ.bits[n] or f" {n}" for n in reversed(range(0, len(typ.bits)))
                ]
                if aliases := [k for k in typ.names if k not in typ.bits]:
                    names.append("")
                    names.extend(aliases)
                prefix = f"{idprefix.upper()}{typ.name.upper()}_"
                namewidth = max(len(add_prefix(prefix, name)) for name in names)

                ret += "\n"
                for name in names:
                    if name == "":
                        ret += "\n"
                        continue

                    if name.startswith(" "):
                        vers = typ.in_versions
                        c_name = ""
                        c_val = f"1<<{name[1:]}"
                    else:
                        vers = typ.names[name].in_versions
                        c_name = add_prefix(prefix, name)
                        c_val = f"{typ.names[name].val}"

                    ret += ifdef_push(2, c_ver_ifdef(vers))

                    # It is important all of the `beg` strings have
                    # the same length.
                    end = ""
                    if name.startswith(" "):
                        beg = "/* unused"
                        end = " */"
                    elif _ifdef_stack[-1]:
                        beg = "#  define"
                    else:
                        beg = "#define  "

                    ret += f"{beg} {c_name.ljust(namewidth)}  (({c_typename(typ)})({c_val})){end}\n"
                ret += ifdef_pop(1)
            case idl.Struct():  # and idl.Message():
                ret += c_typename(typ) + " {"
                if not typ.members:
                    ret += "};\n"
                    continue
                ret += "\n"

                typewidth = max(len(c_typename(m.typ, m)) for m in typ.members)

                for member in typ.members:
                    if member.val:
                        continue
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += f"\t{c_typename(member.typ, member).ljust(typewidth)}  {'*' if member.cnt else ' '}{member.name};\n"
                ret += ifdef_pop(1)
                ret += "};\n"
    ret += ifdef_pop(0)

    return ret


# Generate .c ##################################################################


def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#include <stdbool.h>
#include <stddef.h>   /* for size_t */
#include <inttypes.h> /* for PRI* macros */
#include <string.h>   /* for memset() */

#include <libmisc/assert.h>

#include <lib9p/9p.h>

#include "internal.h"
"""

    # utilities ################################################################
    ret += f"""
/* utilities ******************************************************************/
"""

    def used(arg: str) -> str:
        return arg

    def unused(arg: str) -> str:
        return f"LM_UNUSED({arg})"

    id2typ: dict[int, idl.Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        id2typ[msg.msgid] = msg

    def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
        ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
        for ver in ["unknown", *sorted(versions)]:
            if ver != "unknown":
                ret += ifdef_push(1, c_ver_ifdef({ver}))
            ret += f"\t[{c_ver_enum(ver)}] = {{\n"
            for n in range(*rng):
                xmsg: idl.Message | None = id2typ.get(n, None)
                if xmsg:
                    if ver == "unknown":  # SPECIAL (initialization)
                        if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
                            xmsg = None
                    else:
                        if ver not in xmsg.in_versions:
                            xmsg = None
                if xmsg:
                    ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
            ret += "\t},\n"
        ret += ifdef_pop(0)
        ret += "};\n"
        return ret

    for v in sorted(versions):
        ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n"
        ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n"
        ret += "#else\n"
        ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n"
        ret += "#endif\n"
    ret += "\n"
    ret += "/**\n"
    ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n"
    ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n"
    ret += " * (because `!CONFIG_9P_ENABLE_##ver`).  This is useful when `||`ing\n"
    ret += " * several version checks together.\n"
    ret += " */\n"
    ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n"

    # strings ##################################################################
    ret += f"""
/* strings ********************************************************************/

const char *_lib9p_table_ver_name[{c_ver_enum('NUM')}] = {{
"""
    for ver in ["unknown", *sorted(versions)]:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
    ret += ifdef_pop(0)
    ret += "};\n"

    ret += "\n"
    ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
    ret += msg_table("msg", "name", "char *", (0, 0x100, 1))

    # bitmasks #################################################################
    ret += f"""
/* bitmasks *******************************************************************/
"""
    for typ in typs:
        if not isinstance(typ, idl.Bitfield):
            continue
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
        verwidth = max(len(ver) for ver in versions)
        for ver in sorted(versions):
            ret += ifdef_push(2, c_ver_ifdef({ver}))
            ret += (
                f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
                + "".join(
                    "1" if typ.bit_is_valid(bitname, ver) else "0"
                    for bitname in reversed(typ.bits)
                )
                + ",\n"
            )
        ret += ifdef_pop(1)
        ret += "};\n"
    ret += ifdef_pop(0)

    # validate_* ###############################################################
    ret += """
/* validate_* *****************************************************************/

LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
\t\t/* If needed-net-size overflowed uint32_t, then
\t\t * there's no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\tif (ctx->net_offset > ctx->net_size)
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
\t\t/* If needed-host-size overflowed size_t, then there's
\t\t * no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
                                         size_t cnt,
                                         _validate_fn_t item_fn, size_t item_host_size) {
\tfor (size_t i = 0; i < cnt; i++)
\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
\t\t\treturn true;
\treturn false;
}

LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
    for typ in topo_sorted(typs):
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"

        match typ:
            case idl.Number():
                ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
            case idl.Bitfield():
                ret += f"\t if (validate_{typ.static_size}(ctx))\n"
                ret += "\t\treturn true;\n"
                ret += (
                    f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
                )
                if typ.static_size == 1:
                    ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
                else:
                    ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
                ret += f"\tif (val & ~mask)\n"
                ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
                ret += "\treturn false;\n"
            case idl.Struct():  # and idl.Message()
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                def should_save_value(member: idl.StructMember) -> bool:
                    nonlocal typ
                    assert isinstance(typ, idl.Struct)
                    return bool(
                        member.max
                        or member.val
                        or any(m.cnt == member for m in typ.members)
                    )

                # Pass 1 - declare value variables
                for member in typ.members:
                    if should_save_value(member):
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t{c_typename(member.typ)} {member.name};\n"
                ret += ifdef_pop(1)

                # Pass 2 - declare offset variables
                mark_offset: set[str] = set()
                for member in typ.members:
                    for tok in [*member.max.tokens, *member.val.tokens]:
                        if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 3 - main pass
                ret += "\treturn false\n"
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += f"\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.cnt is not None:
                        ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
                        if typ.name == "s":  # SPECIAL (string)
                            ret += f'\n\t    || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})'
                    else:
                        if should_save_value(member):
                            ret += "("
                        if member.name in mark_offset:
                            ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                        ret += f"validate_{member.typ.name}(ctx)"
                        if member.name in mark_offset:
                            ret += "; })"
                        if should_save_value(member):
                            nbytes = member.static_size
                            assert nbytes
                            if nbytes == 1:
                                ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
                            else:
                                ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"

                # Pass 4 - validate ,max= and ,val= constraints
                for member in typ.members:
                    if member.max:
                        assert member.static_size
                        nbits = member.static_size * 8
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint{nbits}_t max = {c_expr(member.max)}; (((uint{nbits}_t){member.name}) > max) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n'
                    if member.val:
                        assert member.static_size
                        nbits = member.static_size * 8
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint{nbits}_t exp = {c_expr(member.val)}; (((uint{nbits}_t){member.name}) != exp) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n'

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # unmarshal_* ##############################################################
    ret += """
/* unmarshal_* ****************************************************************/

LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
\t*out = ctx->net_bytes[ctx->net_offset];
\tctx->net_offset += 1;
}

LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 2;
}

LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 4;
}

LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 8;
}
"""
    for typ in topo_sorted(typs):
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
        match typ:
            case idl.Number():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case idl.Bitfield():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case idl.Struct():
                ret += "\tmemset(out, 0, sizeof(*out));\n"

                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    if member.val:
                        ret += f"\tctx->net_offset += {member.static_size};\n"
                        continue
                    ret += "\t"

                    prefix = "\t"
                    if member.in_versions != typ.in_versions:
                        ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
                        prefix = "\t\t"
                    if member.cnt:
                        if member.in_versions != typ.in_versions:
                            ret += "{\n"
                            ret += prefix
                        ret += f"out->{member.name} = ctx->extra;\n"
                        ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n"
                        ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n"
                        if member.typ.static_size == 1:  # SPECIAL (string)
                            # Special-case is that we cast from `char` to `uint8_t`.
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
                        else:
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
                        if member.in_versions != typ.in_versions:
                            ret += "\t}\n"
                    else:
                        ret += (
                            f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
                        )
        ret += ifdef_pop(1)
        ret += "}\n"
    ret += ifdef_pop(0)

    # marshal_* ################################################################
    ret += """
/* marshal_* ******************************************************************/

LM_ALWAYS_INLINE static bool _marshal_too_large(struct _marshal_ctx *ctx) {
\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
\t\t(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
\t\tctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
\t\tctx->ctx->max_msg_size);
\treturn true;
}

LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
\tif (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
\t\treturn _marshal_too_large(ctx);
\tctx->net_bytes[ctx->net_offset] = *val;
\tctx->net_offset += 1;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
\tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
\t\treturn _marshal_too_large(ctx);
\tuint16le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 2;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
\tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
\t\treturn true;
\tuint32le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 4;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
\tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
\t\treturn true;
\tuint64le_encode(&ctx->net_bytes[ctx->net_offset], *val);
\tctx->net_offset += 8;
\treturn false;
}
"""
    for typ in topo_sorted(typs):
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
        match typ:
            case idl.Number():
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
            case idl.Bitfield():
                ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
            case idl.Struct():
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare offset variables
                mark_offset = set()
                for member in typ.members:
                    if member.val:
                        if member.name not in mark_offset:
                            ret += f"\tuint32_t _{member.name}_offset;\n"
                        mark_offset.add(member.name)
                    for tok in member.val.tokens:
                        if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 2 - main pass
                ret += "\treturn false\n"
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += "\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.name in mark_offset:
                        ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                    if member.cnt:
                        ret += "({ bool err = false;\n"
                        ret += f"\t          for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n"
                        ret += "\t          \terr = "
                        if member.typ.static_size == 1:  # SPECIAL (string)
                            # Special-case is that we cast from `char` to `uint8_t`.
                            ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
                        else:
                            ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n"
                        ret += f"\t          err; }})"
                    elif member.val:
                        # Just increment net_offset, don't actually marshal anything (yet).
                        assert member.static_size
                        ret += (
                            f"({{ ctx->net_offset += {member.static_size}; false; }})"
                        )
                    else:
                        ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
                    if member.name in mark_offset:
                        ret += "; })"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"

                # Pass 3 - marshal ,val= members
                for member in typ.members:
                    if member.val:
                        assert member.static_size
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        if member.static_size == 1:
                            ret += f"\t    || ({{ ctx->net_bytes[_{member.name}_offset] = {c_expr(member.val)}; false; }})\n"
                        else:
                            ret += f"\t    || ({{ uint{member.static_size*8}le_encode(&ctx->net_bytes[_{member.name}_offset], {c_expr(member.val)}); false; }})\n"

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # function tables ##########################################################
    ret += """
/* function tables ************************************************************/

"""
    ret += c_macro(
        f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
        f"\t\t.basesize  = sizeof(struct {idprefix}msg_##typ),\n"
        f"\t\t.validate  = validate_##typ,\n"
        f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
        f"\t}}\n"
    )
    ret += c_macro(
        f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n"
        f"\t\t.marshal   = (_marshal_fn_t)marshal_##typ,\n"
        f"\t}}\n"
    )
    ret += "\n"
    ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2))
    ret += "\n"
    ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2))
    ret += "\n"
    ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2))
    ret += "\n"
    ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2))

    ret += f"""
LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{
\treturn validate_stat(ctx);
}}
LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{
\tunmarshal_stat(ctx, out);
}}
LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{
\treturn marshal_stat(ctx, val);
}}
"""

    ############################################################################
    return ret


# Main #########################################################################


if __name__ == "__main__":
    # Parse every .9p file named on the command line, then write the
    # generated header and C source relative to this script's directory.
    import sys

    class _PlainANSIColors:
        # ANSI escape codes for the colors used in error output below.
        MAGENTA = "\x1b[35m"
        RED = "\x1b[31m"
        RESET = "\x1b[0m"

    if typing.TYPE_CHECKING:
        ANSIColors = _PlainANSIColors
    else:
        try:
            from _colorize import ANSIColors  # Present in Python 3.13+
        except ImportError:
            # `_colorize` is a private module that older Pythons lack;
            # don't let it prevent the generator from running at all.
            ANSIColors = _PlainANSIColors

    if len(sys.argv) < 2:
        sys.exit(f"{sys.argv[0]}: requires at least 1 .9p filename")
    parser = idl.Parser()
    for txtname in sys.argv[1:]:
        try:
            parser.parse_file(txtname)
        except SyntaxError as e:
            print(
                f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}",
                file=sys.stderr,
            )
            # A SyntaxError need not carry the offending source line; only
            # echo and underline it when present instead of asserting.
            if e.text:
                print(f"\t{e.text}", file=sys.stderr)
                print(
                    f"\t{ANSIColors.RED}{'~'*len(e.text)}{ANSIColors.RESET}",
                    file=sys.stderr,
                )
            sys.exit(2)
    versions, typs = parser.all()
    outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
    with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
        fh.write(gen_h(versions, typs))
    with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
        fh.write(gen_c(versions, typs))