#!/usr/bin/env python
# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
#                 defining 9P protocol variants.
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import os.path
import sys

sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))

import idl

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# Utilities ####################################################################

idprefix = "lib9p_"


def tab_ljust(s: str, width: int) -> str:
    """Left-justify `s` to `width` display columns, like str.ljust().

    Unlike str.ljust(), the current width of `s` is measured with
    tabs expanded to 8-column tab stops.  The tabs themselves are
    kept in the result; only trailing spaces are appended.  A string
    that is already `width` columns or wider is returned unchanged.
    """
    missing = width - len(s.expandtabs(tabsize=8))
    return s + " " * missing if missing > 0 else s


def add_prefix(p: str, s: str) -> str:
    """Prepend the prefix `p` to the identifier `s`.

    If `s` begins with an underscore, that single leading underscore
    stays in front and `p` is inserted right after it.
    """
    lead = "_" if s.startswith("_") else ""
    return lead + p + s[len(lead) :]


def join_lines(*args: str) -> str:
    """Join `args` with newlines into one newline-terminated string.

    Trailing whitespace is stripped from each argument and from the
    joined text as a whole, so the result always ends in exactly one
    "\\n".
    """
    trimmed = (piece.rstrip() for piece in args)
    text = "\n".join(trimmed)
    return text.rstrip() + "\n"


def c_macro(*args: str) -> str:
    """Format `args` as a multi-line C macro definition.

    Every line but the last gets a line-continuation backslash, and
    the backslashes are aligned in one column just past the widest
    continued line (measured with tabs expanded).  The result must
    span at least two lines; a single-line macro trips the assertion.
    """
    body = join_lines(*args).rstrip()
    assert "\n" in body
    rows = [row.rstrip() for row in body.split("\n")]
    # The final row carries no backslash, so it does not influence
    # the alignment column.
    column = max(len(row.expandtabs(tabsize=8)) for row in rows[:-1])
    padded = (tab_ljust(row, column) for row in rows)
    return "  \\\n".join(padded).rstrip() + "\n"


def c_ver_enum(ver: str) -> str:
    """Return the C enum-member name for the protocol version `ver`.

    The name is the upper-cased `idprefix`, then "VER_", then `ver`
    with every "." replaced by "_".
    """
    suffix = ver.replace(".", "_")
    return idprefix.upper() + "VER_" + suffix


def c_ver_ifdef(versions: set[str]) -> str:
    """Return a C preprocessor condition for a set of versions.

    The condition is the CONFIG_9P_ENABLE_* macro of each version
    ("." replaced by "_"), in sorted order, joined with " || ".
    """
    macros = ["CONFIG_9P_ENABLE_" + v.replace(".", "_") for v in sorted(versions)]
    return " || ".join(macros)


def c_ver_cond(versions: set[str]) -> str:
    """Return a C run-time expression testing the context's version.

    For exactly one version this is a single parenthesized `==`
    comparison of `ctx->ctx->version` against that version's enum
    member; otherwise it is the per-version comparisons, in sorted
    order, joined with " || " inside an outer pair of parentheses.
    """
    if len(versions) != 1:
        alternatives = [c_ver_cond({v}) for v in sorted(versions)]
        return "( " + " || ".join(alternatives) + " )"
    (only,) = versions
    return f"(ctx->ctx->version=={c_ver_enum(only)})"


def c_typename(typ: idl.Type) -> str:
    """Return the C type name used for the IDL type `typ`.

    Raises ValueError if `typ` is none of the known IDL type classes.
    """
    if isinstance(typ, idl.Primitive):
        return f"uint{typ.value*8}_t"
    if isinstance(typ, (idl.Number, idl.Bitfield)):
        return f"{idprefix}{typ.name}_t"
    # idl.Message is tested before idl.Struct, same as the order of
    # the checks this replaces.
    if isinstance(typ, idl.Message):
        return f"struct {idprefix}msg_{typ.name}"
    if isinstance(typ, idl.Struct):
        return f"struct {idprefix}{typ.name}"
    raise ValueError(f"not a type: {typ.__class__.__name__}")


def c_expr(expr: idl.Expr) -> str:
    """Render the IDL expression `expr` as a C expression string.

    Tokens are rendered one by one and joined with single spaces:
    operators and literals verbatim, the symbol "end" as
    `ctx->net_offset`, and any other symbol as the name of the
    `_<name>_offset` local (the symbol's first character is dropped).
    Tokens of any other class are silently skipped.
    """
    pieces: list[str] = []
    for tok in expr.tokens:
        if isinstance(tok, idl.ExprOp):
            pieces.append(tok.op)
        elif isinstance(tok, idl.ExprLit):
            pieces.append(str(tok.val))
        elif isinstance(tok, idl.ExprSym):
            if tok.name == "end":
                pieces.append("ctx->net_offset")
            else:
                pieces.append(f"_{tok.name[1:]}_offset")
    return " ".join(pieces)


_ifdef_stack: list[str | None] = []


def ifdef_push(n: int, _newval: str) -> str:
    # Grow the stack as needed
    global _ifdef_stack
    while len(_ifdef_stack) < n:
        _ifdef_stack += [None]

    # Set some variables
    parentval: str | None = None
    for x in _ifdef_stack[:-1]:
        if x is not None:
            parentval = x
    oldval = _ifdef_stack[-1]
    newval: str | None = _newval
    if newval == parentval:
        newval = None

    # Put newval on the stack.
    _ifdef_stack[-1] = newval

    # Build output.
    ret = ""
    if newval != oldval:
        if oldval is not None:
            ret += f"#endif /* {oldval} */\n"
        if newval is not None:
            ret += f"#if {newval}\n"
    return ret


def ifdef_pop(n: int) -> str:
    """Shrink the #if stack down to `n` levels.

    Returns one `#endif` line for every removed level that had a
    condition open, innermost level first; the empty string if
    nothing needed closing.
    """
    global _ifdef_stack
    closers: list[str] = []
    while len(_ifdef_stack) > n:
        top = _ifdef_stack[-1]
        if top is not None:
            closers.append(f"#endif /* {top} */\n")
        _ifdef_stack = _ifdef_stack[:-1]
    return "".join(closers)


# Generate .h ##################################################################


def gen_h(versions: set[str], typs: list[idl.Type]) -> str:
    """Build and return the text of the generated C header.

    The output is assembled in this order: the "Generated by" banner
    and the direct-inclusion check, the config.h checks, the version
    enum, the message-type enum, and finally one declaration per
    entry of `typs`.  Version-specific parts are wrapped in
    `#if`/`#endif` by way of ifdef_push()/ifdef_pop().

    versions: names of all protocol versions; each one gets a
              CONFIG_9P_ENABLE_* check and a version-enum member.
    typs:     the types to declare, emitted in list order.
    """
    # Start from a clean #if-nesting state.
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#ifndef _LIB9P_9P_H_
\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
#endif

#include <stdint.h> /* for uint{{n}}_t types */
"""

    ret += f"""
/* config *********************************************************************/

#include "config.h"
"""
    # Require config.h to define every version's enable-macro.
    for ver in sorted(versions):
        ret += "\n"
        ret += f"#ifndef {c_ver_ifdef({ver})}\n"
        ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
        ret += "#endif\n"

    ret += f"""
/* enum version ***************************************************************/

enum {idprefix}version {{
"""
    # "unknown = 0" is not a real version: it is emitted outside of
    # any #if, and its " = 0" ends up as the enum initializer.
    fullversions = ["unknown = 0", *sorted(versions)]
    verwidth = max(len(v) for v in fullversions)
    for ver in fullversions:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f"\t{c_ver_enum(ver)},"
        # Trailing comment with the version string, padded so that
        # the comments line up.
        ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
    ret += ifdef_pop(0)
    # The NUM member comes after all #endif lines, so it is always
    # declared.
    ret += f"\t{c_ver_enum('NUM')},\n"
    ret += "};\n"
    ret += "\n"
    ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"

    ret += """
/* enum msg_type **************************************************************/

"""
    ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
    # NOTE: max() raises ValueError if `typs` contains no messages.
    namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message))
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
        ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
    ret += ifdef_pop(0)
    ret += "};\n"

    ret += """
/* payload types **************************************************************/
"""
    for typ in typs:
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        match typ:
            case idl.Number():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
            case idl.Bitfield():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
                # Rows to print: every bit position from the highest
                # index down to 0 (a bit without a name becomes the
                # placeholder " <n>", recognizable by its leading
                # space), then "" as a blank-line separator, then
                # every name in typ.names that is not a bit name.
                names = [
                    *reversed(
                        [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))]
                    ),
                    "",
                    *[k for k in typ.names if k not in typ.bits],
                ]
                prefix = f"{idprefix.upper()}{typ.name.upper()}_"
                namewidth = max(len(add_prefix(prefix, name)) for name in names)

                ret += "\n"
                for name in names:
                    if name == "":
                        ret += "\n"
                        continue

                    if name.startswith(" "):
                        # Unnamed bit: no macro name, only a
                        # commented-out value.
                        vers = typ.in_versions
                        c_name = ""
                        c_val = f"1<<{name[1:]}"
                    else:
                        vers = typ.names[name].in_versions
                        c_name = add_prefix(prefix, name)
                        c_val = f"{typ.names[name].val}"

                    ret += ifdef_push(2, c_ver_ifdef(vers))

                    # It is important all of the `beg` strings have
                    # the same length.
                    end = ""
                    if name.startswith(" "):
                        beg = "/* unused"
                        end = " */"
                    elif _ifdef_stack[-1]:
                        # Inside a nested #if: indent the directive.
                        beg = "#  define"
                    else:
                        beg = "#define  "

                    ret += f"{beg} {c_name.ljust(namewidth)}  (({c_typename(typ)})({c_val})){end}\n"
                ret += ifdef_pop(1)
            case idl.Struct():  # and idl.Message():
                ret += c_typename(typ) + " {"
                if not typ.members:
                    ret += "};\n"
                    continue
                ret += "\n"

                typewidth = max(len(c_typename(m.typ)) for m in typ.members)

                for member in typ.members:
                    # Members with a ,val= expression get no field in
                    # the C struct.
                    if member.val:
                        continue
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    c_type = c_typename(member.typ)
                    if (typ.name in ["d", "s"]) and member.cnt:  # SPECIAL
                        c_type = "char"
                    # A member with a count is declared as a pointer.
                    ret += f"\t{c_type.ljust(typewidth)}  {'*' if member.cnt else ' '}{member.name};\n"
                ret += ifdef_pop(1)
                ret += "};\n"
    ret += ifdef_pop(0)

    return ret


# Generate .c ##################################################################


def gen_c(versions: set[str], typs: list[idl.Type]) -> str:
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#include <stdbool.h>
#include <stddef.h>   /* for size_t */
#include <inttypes.h> /* for PRI* macros */
#include <string.h>   /* for memset() */

#include <libmisc/assert.h>

#include <lib9p/9p.h>

#include "internal.h"
"""

    def used(arg: str) -> str:
        return arg

    def unused(arg: str) -> str:
        return f"LM_UNUSED({arg})"

    # strings ##################################################################
    ret += f"""
/* strings ********************************************************************/

static const char *version_strs[{c_ver_enum('NUM')}] = {{
"""
    for ver in ["unknown", *sorted(versions)]:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
    ret += ifdef_pop(0)
    ret += "};\n"
    ret += f"""
const char *{idprefix}version_str(enum {idprefix}version ver) {{
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtype-limits"
    assert(0 <= ver && ver < {c_ver_enum('NUM')});
#pragma GCC diagnostic pop
    return version_strs[ver];
}}
"""

    # bitmasks #################################################################
    ret += f"""
/* bitmasks *******************************************************************/
"""
    for typ in typs:
        if not isinstance(typ, idl.Bitfield):
            continue
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
        verwidth = max(len(ver) for ver in versions)
        for ver in sorted(versions):
            ret += ifdef_push(2, c_ver_ifdef({ver}))
            ret += (
                f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
                + "".join(
                    "1" if typ.bit_is_valid(bitname, ver) else "0"
                    for bitname in reversed(typ.bits)
                )
                + ",\n"
            )
        ret += ifdef_pop(1)
        ret += "};\n"
    ret += ifdef_pop(0)

    # validate_* ###############################################################
    ret += """
/* validate_* *****************************************************************/

LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
\t\t/* If needed-net-size overflowed uint32_t, then
\t\t * there's no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\tif (ctx->net_offset > ctx->net_size)
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
\t\t/* If needed-host-size overflowed size_t, then there's
\t\t * no way that actual-net-size will live up to
\t\t * that.  */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}

LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
                                         size_t cnt,
                                         _validate_fn_t item_fn, size_t item_host_size) {
\tfor (size_t i = 0; i < cnt; i++)
\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
\t\t\treturn true;
\treturn false;
}

#define validate_1(ctx) _validate_size_net(ctx, 1)
#define validate_2(ctx) _validate_size_net(ctx, 2)
#define validate_4(ctx) _validate_size_net(ctx, 4)
#define validate_8(ctx) _validate_size_net(ctx, 8)
"""
    for typ in typs:
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"

        if typ.name == "d":  # SPECIAL
            # Optimize... maybe the compiler could figure out to do
            # this, but let's make it obvious.
            ret += "\tuint32_t base_offset = ctx->net_offset;\n"
            ret += "\tif (validate_4(ctx))\n"
            ret += "\t\treturn true;\n"
            ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
            ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
            ret += "}\n"
            continue
        if typ.name == "s":  # SPECIAL
            # Add an extra nul-byte on the host, and validate UTF-8
            # (also, similar optimization to "d").
            ret += "\tuint32_t base_offset = ctx->net_offset;\n"
            ret += "\tif (validate_2(ctx))\n"
            ret += "\t\treturn true;\n"
            ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
            ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
            ret += "\t\treturn true;\n"
            ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
            ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
            ret += "\treturn false;\n"
            ret += "}\n"
            continue

        match typ:
            case idl.Number():
                ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
            case idl.Bitfield():
                ret += f"\t if (validate_{typ.static_size}(ctx))\n"
                ret += "\t\treturn true;\n"
                ret += (
                    f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
                )
                ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
                ret += f"\tif (val & ~mask)\n"
                ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
                ret += "\treturn false;\n"
            case idl.Struct():  # and idl.Message()
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare value variables
                for member in typ.members:
                    if member.max or member.val:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t{c_typename(member.typ)} {member.name};\n"
                ret += ifdef_pop(1)

                # Pass 2 - declare offset variables
                mark_offset: set[str] = set()
                for member in typ.members:
                    for tok in [*member.max.tokens, *member.val.tokens]:
                        if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 3 - main pass
                ret += "\treturn false\n"
                prev_size: int | None = None
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += f"\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.cnt is not None:
                        assert prev_size
                        ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
                    else:
                        if member.max or member.val:
                            ret += "("
                        if member.name in mark_offset:
                            ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                        ret += f"validate_{member.typ.name}(ctx)"
                        if member.name in mark_offset:
                            ret += "; })"
                        if member.max or member.val:
                            bytes = member.static_size
                            assert bytes
                            bits = bytes * 8
                            ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"
                    prev_size = member.static_size

                # Pass 4 - validate ,max= and ,val= constraints
                for member in typ.members:
                    if member.max:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
                    if member.val:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # unmarshal_* ##############################################################
    ret += """
/* unmarshal_* ****************************************************************/

LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
\t*out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 1;
}

LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
\t*out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 2;
}

LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
\t*out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 4;
}

LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
\t*out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 8;
}
"""
    for typ in typs:
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
        match typ:
            case idl.Number():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case idl.Bitfield():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case idl.Struct():
                ret += "\tmemset(out, 0, sizeof(*out));\n"

                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    if member.val:
                        ret += f"\tctx->net_offset += {member.static_size};\n"
                        continue
                    ret += "\t"

                    prefix = "\t"
                    if member.in_versions != typ.in_versions:
                        ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
                        prefix = "\t\t"
                    if member.cnt:
                        if member.in_versions != typ.in_versions:
                            ret += "{\n"
                            ret += prefix
                        ret += f"out->{member.name} = ctx->extra;\n"
                        ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
                        ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
                        if typ.name in ["d", "s"]:  # SPECIAL
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
                        else:
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
                        if member.in_versions != typ.in_versions:
                            ret += "\t}\n"
                    else:
                        ret += (
                            f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
                        )
                if typ.name == "s":  # SPECIAL
                    ret += "\tctx->extra++;\n"
                    ret += "\tout->utf8[out->len] = '\\0';\n"
        ret += ifdef_pop(1)
        ret += "}\n"
    ret += ifdef_pop(0)

    # marshal_* ################################################################
    ret += """
/* marshal_* ******************************************************************/

LM_ALWAYS_INLINE static bool _marshal_too_large(struct _marshal_ctx *ctx) {
\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
\t\t(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
\t\tctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
\t\tctx->ctx->max_msg_size);
\treturn true;
}

LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
\tif (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
\t\treturn _marshal_too_large(ctx);
\tctx->net_bytes[ctx->net_offset] = *val;
\tctx->net_offset += 1;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
\tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
\t\treturn _marshal_too_large(ctx);
\tencode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 2;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
\tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
\t\treturn true;
\tencode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 4;
\treturn false;
}

LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
\tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
\t\treturn true;
\tencode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
\tctx->net_offset += 8;
\treturn false;
}
"""
    for typ in typs:
        inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
        match typ:
            case idl.Number():
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
            case idl.Bitfield():
                ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
            case idl.Struct():
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare offset variables
                mark_offset = set()
                for member in typ.members:
                    if member.val:
                        if member.name not in mark_offset:
                            ret += f"\tuint32_t _{member.name}_offset;\n"
                        mark_offset.add(member.name)
                    for tok in member.val.tokens:
                        if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 2 - main pass
                ret += "\treturn false\n"
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += "\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.name in mark_offset:
                        ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                    if member.cnt:
                        ret += "({ bool err = false;\n"
                        ret += f"\t          for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n"
                        ret += "\t          \terr = "
                        if typ.name in ["d", "s"]:  # SPECIAL
                            # Special-case is that we cast from `char` to `uint8_t`.
                            ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
                        else:
                            ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n"
                        ret += f"\t          err; }})"
                    elif member.val:
                        # Just increment net_offset, don't actually marshal anything (yet).
                        assert member.static_size
                        ret += (
                            f"({{ ctx->net_offset += {member.static_size}; false; }})"
                        )
                    else:
                        ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
                    if member.name in mark_offset:
                        ret += "; })"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"

                # Pass 3 - marshal ,val= members
                for member in typ.members:
                    if member.val:
                        assert member.static_size
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n"

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # tables / exports #########################################################
    ret += """
/* tables / exports ***********************************************************/
"""
    id2typ: dict[int, idl.Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        id2typ[msg.msgid] = msg

    ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n"
    ret += c_macro(
        f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{",
        f"\t\t.basesize  = sizeof(struct {idprefix}msg_##typ),",
        f"\t\t.validate  = validate_##typ,",
        f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,",
        f"\t}}",
    )
    ret += c_macro(
        f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{",
        f"\t\t.marshal   = (_marshal_fn_t)marshal_##typ,",
        f"\t}}",
    )

    tables = [
        ("msg", "name", "char *", (0, 0x100, 1)),
        ("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)),
        ("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)),
        ("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)),
        ("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)),
    ]
    for grp, meth, tentry, rng in tables:
        ret += "\n"
        ret += f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n"
        for ver in ["unknown", *sorted(versions)]:
            if ver != "unknown":
                ret += ifdef_push(1, c_ver_ifdef({ver}))
            ret += f"\t[{c_ver_enum(ver)}] = {{\n"
            for n in range(*rng):
                xmsg: idl.Message | None = id2typ.get(n, None)
                if xmsg:
                    if ver == "unknown":  # SPECIAL
                        if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
                            xmsg = None
                    else:
                        if ver not in xmsg.in_versions:
                            xmsg = None
                if xmsg:
                    ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n"
            ret += "\t},\n"
        ret += ifdef_pop(0)
        ret += "};\n"

    ret += f"""
LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{
\treturn validate_stat(ctx);
}}
LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{
\tunmarshal_stat(ctx, out);
}}
LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{
\treturn marshal_stat(ctx, val);
}}
"""

    ############################################################################
    return ret


# Main #########################################################################


if __name__ == "__main__":
    # Usage: <this script> FILE.9p...
    #
    # Parses every input file, then writes the generated header and
    # source file to fixed paths relative to the directory of the
    # script path in sys.argv[0].
    inputs = sys.argv[1:]
    if not inputs:
        raise ValueError("requires at least 1 .9p filename")

    parser = idl.Parser()
    for txtname in inputs:
        parser.parse_file(txtname)
    versions, typs = parser.all()

    outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
    outputs = (
        ("include/lib9p/9p.generated.h", gen_h),
        ("9p.generated.c", gen_c),
    )
    for relpath, generate in outputs:
        with open(os.path.join(outdir, relpath), "w") as fh:
            fh.write(generate(versions, typs))