#!/usr/bin/env python
# lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files
#                 defining 9P protocol variants.
#
# Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import enum
import os.path
import re
from abc import ABC, abstractmethod
from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".

# Types ########################################################################


class Primitive(enum.Enum):
    """The built-in unsigned integer types.

    Each member's value doubles as its size in bytes (it is multiplied
    by 8 elsewhere to obtain a bit width).
    """

    u8 = 1
    u16 = 2
    u32 = 4
    u64 = 8

    @property
    def static_size(self) -> int:
        """The member's value, used as its fixed size."""
        return self.value

    @property
    def name(self) -> str:
        """The member's value as a decimal string (e.g. "4" for u32).

        This deliberately replaces the usual Enum member name.
        """
        return f"{self.value}"

    @property
    def in_versions(self) -> set[str]:
        """Always a new, empty set."""
        return set()


class Number:
    """A named numeric type that is represented by a Primitive."""

    name: str
    in_versions: set[str]

    prim: Primitive

    def __init__(self) -> None:
        # Only in_versions gets a default; name and prim are assigned
        # by whoever constructs the object.
        self.in_versions = set()

    @property
    def static_size(self) -> int:
        """Size of the underlying primitive."""
        underlying = self.prim
        return underlying.static_size


class BitfieldVal:
    """One named value (a bit or an alias) belonging to a Bitfield."""

    name: str
    in_versions: set[str]

    # The value as a string (e.g. "1<<3" for a bit).
    val: str

    def __init__(self) -> None:
        # name and val are assigned by the caller after construction.
        self.in_versions = set()


class Bitfield:
    """A named set of bit flags stored in a Primitive."""

    name: str
    in_versions: set[str]

    prim: Primitive

    # Indexed by bit number; "" marks a bit that has no name.
    bits: list[str]  # bitnames
    names: dict[str, BitfieldVal]  # bits *and* aliases

    def __init__(self) -> None:
        # `bits`, `name` and `prim` are assigned by the caller.
        self.names = {}
        self.in_versions = set()

    @property
    def static_size(self) -> int:
        """Size of the underlying primitive."""
        underlying = self.prim
        return underlying.static_size

    def bit_is_valid(self, bit: str | int, ver: str | None = None) -> bool:
        """Return whether the given bit is valid in the given protocol
        version.

        `bit` is either a bit number (an index into `bits`) or a bit
        name.  Unnamed bits and names starting with "_" are never
        valid.  With no `ver`, any other bit is valid.

        """
        if isinstance(bit, int):
            bitname = self.bits[bit]
        else:
            bitname = bit
        assert bitname in self.bits
        if bitname == "" or bitname.startswith("_"):
            return False
        return (not ver) or (ver in self.names[bitname].in_versions)


class ExprLit:
    """An integer-literal token of an expression."""

    val: int

    def __init__(self, val: int) -> None:
        self.val = val


class ExprSym:
    """A symbol token of an expression, stored exactly as written."""

    name: str

    def __init__(self, name: str) -> None:
        self.name = name


class ExprOp:
    """An operator token of an expression: either "-" or "+"."""

    op: Literal["-", "+"]

    def __init__(self, op: Literal["-", "+"]) -> None:
        self.op = op


class Expr:
    """A flat sequence of expression tokens.

    An Expr with no tokens is falsy.
    """

    tokens: list[ExprLit | ExprSym | ExprOp]

    def __init__(self) -> None:
        self.tokens = []

    def __bool__(self) -> bool:
        return bool(self.tokens)


class StructMember:
    """One member of a Struct or Message."""

    # from left-to-right when parsing
    cnt: str | None = None  # name of the count member, if this member is repeated
    name: str
    typ: "Type"
    max: Expr
    val: Expr

    in_versions: set[str]

    @property
    def static_size(self) -> int | None:
        """Size of the member's type, or None if the member has a count."""
        return None if self.cnt else self.typ.static_size


class Struct:
    """A named, ordered collection of StructMembers."""

    name: str
    in_versions: set[str]

    members: list[StructMember]

    def __init__(self) -> None:
        # `name` and `members` are assigned by the caller.
        self.in_versions = set()

    @property
    def static_size(self) -> int | None:
        """Sum of the members' sizes, or None as soon as one member has
        no static size."""
        total = 0
        for member in self.members:
            if (member_size := member.static_size) is None:
                return None
            total += member_size
        return total


class Message(Struct):
    """A Struct that is a protocol message and so carries a message ID."""

    @property
    def msgid(self) -> int:
        """The literal value given for the second member ("typ").

        Asserts the expected layout: at least three members, the
        second one named "typ", one byte wide, with a `val` that is a
        single integer literal.
        """
        assert len(self.members) >= 3
        typ_member = self.members[1]
        assert typ_member.name == "typ"
        assert typ_member.static_size == 1
        assert typ_member.val
        assert len(typ_member.val.tokens) == 1
        literal = typ_member.val.tokens[0]
        assert isinstance(literal, ExprLit)
        return literal.val


# Every kind of object that a type name can map to in a parse environment.
Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message
# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
# Type variable constrained to the non-primitive type classes.
T = TypeVar("T", Number, Bitfield, Struct, Message)

# Parse *.9p ###################################################################

# Regex fragments; each is wrapped in a non-capturing group so that it
# can be embedded in the larger patterns below.
re_priname = "(?:1|2|4|8)"  # primitive names
re_symname = "(?:[a-zA-Z_][a-zA-Z_0-9]*)"  # "symbol" names; most *.9p-defined names
re_impname = r"(?:\*|" + re_symname + ")"  # names we can import
re_msgname = r"(?:[TR][a-zA-Z_0-9]*)"  # names a message can be

re_memtype = f"(?:{re_symname}|{re_priname})"  # typenames that a struct member can be

# An expression: any non-empty run of "-", "+", decimal numbers, and
# symbol names (a symbol may carry a leading "&").
re_expr = f"(?:(?:-|\\+|[0-9]+|&?{re_symname})+)"

# A bitfield entry that names a bit: BITNUMBER=NAME.
re_bitspec_bit = f"(?P<bit>[0-9]+)\\s*=\\s*(?P<name>{re_symname})"
# A bitfield entry that defines an alias: NAME=VALUE, where VALUE is
# any run of non-whitespace characters.
re_bitspec_alias = f"(?P<name>{re_symname})\\s*=\\s*(?P<val>\\S+)"

# A struct/message member: NAME[TYPE], with optional ",max=EXPR" and
# ",val=EXPR" after TYPE, and optionally wrapped as CNT*(NAME[TYPE]).
# The opening "CNT*(" and the closing ")" are each optional on their
# own; the pattern does not require them to be paired.
re_memberspec = f"(?:(?P<cnt>{re_symname})\\*\\()?(?P<name>{re_symname})\\[(?P<typ>{re_memtype})(?:,max=(?P<max>{re_expr})|,val=(?P<val>{re_expr}))*\\]\\)?"


def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None:
    """Parse one bitfield entry and record it on `bf`.

    `spec` is either a bit definition (BITNUMBER=NAME) or an alias
    (NAME=VALUE).  The new entry is marked as present in `ver` and
    stored in `bf.names`; a bit definition also claims its slot in
    `bf.bits`.

    Raises SyntaxError if `spec` has neither form, and ValueError if
    the bit number is out of range, the bit is already taken, or the
    name is already in use.
    """
    spec = spec.strip()

    def make_entry(name: str, value: str) -> BitfieldVal:
        entry = BitfieldVal()
        entry.name = name
        entry.val = value
        entry.in_versions.add(ver)
        return entry

    new: BitfieldVal
    if bit_match := re.fullmatch(re_bitspec_bit, spec):
        bitnum = int(bit_match.group("bit"))
        new = make_entry(bit_match.group("name"), f"1<<{bitnum}")

        if not (0 <= bitnum < len(bf.bits)):
            raise ValueError(f"{bf.name}: bit {bitnum} is out-of-bounds")
        if bf.bits[bitnum]:
            raise ValueError(f"{bf.name}: bit {bitnum} already assigned")
        # The slot is claimed before the name-collision check below.
        bf.bits[bitnum] = new.name
    elif alias_match := re.fullmatch(re_bitspec_alias, spec):
        new = make_entry(alias_match.group("name"), alias_match.group("val"))
    else:
        raise SyntaxError(f"invalid bitfield spec {repr(spec)}")

    if new.name in bf.names:
        raise ValueError(f"{bf.name}: name {new.name} already assigned")
    bf.names[new.name] = new


def parse_expr(expr: str) -> Expr:
    """Tokenize an expression string into an Expr.

    The string is split on "-" and "+".  Each operator becomes an
    ExprOp, each all-digit piece an ExprLit, and every other piece
    (including an empty one, e.g. before a leading operator) an
    ExprSym.
    """
    assert re.fullmatch(re_expr, expr)
    parsed = Expr()
    for piece in re.split("([-+])", expr):
        token: ExprLit | ExprSym | ExprOp
        if piece in ("-", "+"):
            # The cast() narrows `str` to the Literal type that ExprOp
            # is annotated with, for the type checker's benefit.
            token = ExprOp(cast(Literal["-", "+"], piece))
        elif re.fullmatch("[0-9]+", piece):
            token = ExprLit(int(piece))
        else:
            token = ExprSym(piece)
        parsed.tokens.append(token)
    return parsed


def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> None:
    """Parse whitespace-separated member specs and append them to `struct`.

    Each spec must fully match `re_memberspec`.  Every new member is
    marked as present in `ver` only, and its type is looked up by
    name in `env`.

    Raises:
        SyntaxError: a spec does not match `re_memberspec`.
        NameError: a spec names a type that is not in `env`.
        ValueError: the member name is already used in `struct`; a
            list count does not name the immediately preceding member
            or that member is not a Primitive; or `,max=`/`,val=` is
            given on a member that is repeated or not a Primitive.
    """
    for spec in specs.split():
        m = re.fullmatch(re_memberspec, spec)
        if not m:
            raise SyntaxError(f"invalid member spec {repr(spec)}")

        member = StructMember()
        member.in_versions = {ver}

        member.name = m.group("name")
        if any(x.name == member.name for x in struct.members):
            raise ValueError(f"duplicate member name {repr(member.name)}")

        # Report the *type* name on failure (positional group 2 is the
        # member name, not the type).
        typname = m.group("typ")
        if typname not in env:
            raise NameError(f"Unknown type {repr(typname)}")
        member.typ = env[typname]

        if cnt := m.group("cnt"):
            # A repeated member's count must be the member right before it.
            if len(struct.members) == 0 or struct.members[-1].name != cnt:
                raise ValueError(f"list count must be previous item: {repr(cnt)}")
            if not isinstance(struct.members[-1].typ, Primitive):
                raise ValueError(f"list count must be an integer type: {repr(cnt)}")
            member.cnt = cnt

        if maxstr := m.group("max"):
            if (not isinstance(member.typ, Primitive)) or member.cnt:
                raise ValueError("',max=' may only be specified on a non-repeated atom")
            member.max = parse_expr(maxstr)
        else:
            member.max = Expr()

        if valstr := m.group("val"):
            if (not isinstance(member.typ, Primitive)) or member.cnt:
                raise ValueError("',val=' may only be specified on a non-repeated atom")
            member.val = parse_expr(valstr)
        else:
            member.val = Expr()

        struct.members += [member]


def re_string(grpname: str) -> str:
    """Return a regex for a double-quoted string whose contents (any
    characters other than a double quote) are captured in the group
    named `grpname`."""
    return '"(?P<' + grpname + '>[^"]*)"'


# One pattern per kind of line in a *.9p file.  parse_file() strips
# "#" comments and trailing whitespace, then tries these with
# re.fullmatch().

# version "VERSION"
re_line_version = f"version\\s+{re_string('version')}"
# from FILE import NAME, NAME, ...   (a NAME may be "*")
re_line_import = f"from\\s+(?P<file>\\S+)\\s+import\\s+(?P<syms>{re_impname}(?:\\s*,\\s*{re_impname})*)"
# num NAME = PRIMITIVE
re_line_num = f"num\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
# bitfield NAME = PRIMITIVE
re_line_bitfield = f"bitfield\\s+(?P<name>{re_symname})\\s*=\\s*(?P<prim>{re_priname})"
# bitfield NAME += "BITSPEC"
re_line_bitfield_ = (
    f"bitfield\\s+(?P<name>{re_symname})\\s*\\+=\\s*{re_string('member')}"
)
# struct NAME = "MEMBERS"   or   struct NAME += "MEMBERS"
re_line_struct = (
    f"struct\\s+(?P<name>{re_symname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
)
# msg NAME = "MEMBERS"   or   msg NAME += "MEMBERS"
re_line_msg = (
    f"msg\\s+(?P<name>{re_msgname})\\s*(?P<op>\\+?=)\\s*{re_string('members')}"
)
# An indented quoted string that continues the previous definition.
re_line_cont = f"\\s+{re_string('specs')}"  # could be bitfield/struct/msg


def parse_file(
    filename: str, get_include: Callable[[str], tuple[str, list[Type]]]
) -> tuple[str, list[Type]]:
    """Parse one *.9p file.

    `get_include` is called with the file name from each `from FILE
    import ...` line and must return that file's (version, types)
    pair.  Imported type objects are shared rather than copied: this
    function adds the current file's version to their `in_versions`
    sets (and to those of their values/members that were present in
    the imported file's version).

    Returns the file's version string and every non-primitive type in
    its environment, imported ones included.

    Raises SyntaxError for an unrecognized line, a misplaced
    continuation line, or a missing/repeated version line; NameError
    for an unknown or wrongly-kinded type name; and ValueError for
    invalid definitions (raised here for bad expression symbols, or by
    parse_bitspec()/parse_members()).
    """
    version: str | None = None
    # Name -> type for everything visible in this file; starts with
    # just the primitives.
    env: dict[str, Type] = {
        "1": Primitive.u8,
        "2": Primitive.u16,
        "4": Primitive.u32,
        "8": Primitive.u64,
    }

    def get_type(name: str, tc: type[T]) -> T:
        """Look up `name` in env, requiring it to be exactly a `tc`."""
        nonlocal env
        if name not in env:
            raise NameError(f"Unknown type {repr(name)}")
        ret = env[name]
        # The class-name comparison rejects subclasses (e.g. a Message
        # where a Struct was asked for).
        if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__):
            raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}")
        return ret

    with open(filename, "r") as fh:
        # The most recently defined/extended type; the target of any
        # continuation line.
        prev: Type | None = None
        for line in fh:
            # Drop everything from the first "#" on, then trailing whitespace.
            line = line.split("#", 1)[0].rstrip()
            if not line:
                continue
            if m := re.fullmatch(re_line_version, line):
                if version:
                    raise SyntaxError("must have exactly 1 version line")
                version = m.group("version")
                continue
            # Every other kind of line requires the version to be known.
            if not version:
                raise SyntaxError("must have exactly 1 version line")

            if m := re.fullmatch(re_line_import, line):
                other_version, other_typs = get_include(m.group("file"))
                for symname in m.group("syms").split(sep=","):
                    symname = symname.strip()
                    for typ in other_typs:
                        if typ.name == symname or symname == "*":
                            # Mark the imported type as also existing in
                            # this version; values/members carry over only
                            # if they existed in the imported file's version.
                            match typ:
                                case Primitive():
                                    pass
                                case Number():
                                    typ.in_versions.add(version)
                                case Bitfield():
                                    typ.in_versions.add(version)
                                    for val in typ.names.values():
                                        if other_version in val.in_versions:
                                            val.in_versions.add(version)
                                case Struct():  # and Message()
                                    typ.in_versions.add(version)
                                    for member in typ.members:
                                        if other_version in member.in_versions:
                                            member.in_versions.add(version)
                            env[typ.name] = typ
            elif m := re.fullmatch(re_line_num, line):
                num = Number()
                num.name = m.group("name")
                num.in_versions.add(version)

                prim = env[m.group("prim")]
                assert isinstance(prim, Primitive)
                num.prim = prim

                env[num.name] = num
                prev = num
            elif m := re.fullmatch(re_line_bitfield, line):
                bf = Bitfield()
                bf.name = m.group("name")
                bf.in_versions.add(version)

                prim = env[m.group("prim")]
                assert isinstance(prim, Primitive)
                bf.prim = prim

                # One (initially unnamed) slot per bit of the primitive.
                bf.bits = (prim.static_size * 8) * [""]

                env[bf.name] = bf
                prev = bf
            elif m := re.fullmatch(re_line_bitfield_, line):
                bf = get_type(m.group("name"), Bitfield)
                parse_bitspec(version, bf, m.group("member"))

                prev = bf
            elif m := re.fullmatch(re_line_struct, line):
                match m.group("op"):
                    case "=":
                        struct = Struct()
                        struct.name = m.group("name")
                        struct.in_versions.add(version)
                        struct.members = []
                        parse_members(version, env, struct, m.group("members"))

                        env[struct.name] = struct
                        prev = struct
                    case "+=":
                        struct = get_type(m.group("name"), Struct)
                        parse_members(version, env, struct, m.group("members"))

                        prev = struct
            elif m := re.fullmatch(re_line_msg, line):
                match m.group("op"):
                    case "=":
                        msg = Message()
                        msg.name = m.group("name")
                        msg.in_versions.add(version)
                        msg.members = []
                        parse_members(version, env, msg, m.group("members"))

                        env[msg.name] = msg
                        prev = msg
                    case "+=":
                        msg = get_type(m.group("name"), Message)
                        parse_members(version, env, msg, m.group("members"))

                        prev = msg
            elif m := re.fullmatch(re_line_cont, line):
                match prev:
                    case Bitfield():
                        parse_bitspec(version, prev, m.group("specs"))
                    case Struct():  # and Message()
                        parse_members(version, env, prev, m.group("specs"))
                    case _:
                        raise SyntaxError(
                            "continuation line must come after a bitfield, struct, or msg line"
                        )
            else:
                raise SyntaxError(f"invalid line {repr(line)}")
    if not version:
        raise SyntaxError("must have exactly 1 version line")

    typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)]

    # Check that every symbol used in a max=/val= expression is either
    # "end" or "&" followed by the name of a member of the same struct.
    for typ in [typ for typ in typs if isinstance(typ, Struct)]:
        valid_syms = ["end", *["&" + m.name for m in typ.members]]
        for member in typ.members:
            for tok in [*member.max.tokens, *member.val.tokens]:
                if isinstance(tok, ExprSym) and tok.name not in valid_syms:
                    raise ValueError(
                        f"{typ.name}.{member.name}: invalid sym: {tok.name}"
                    )

    return version, typs


# Generate C ###################################################################

# Prefix put on generated C identifiers.
idprefix = "lib9p_"


def c_ver_enum(ver: str) -> str:
    """Return the C enumerator name for a version: the upper-cased
    identifier prefix, "VER_", then `ver` with "." replaced by "_"."""
    suffix = ver.replace(".", "_")
    return idprefix.upper() + "VER_" + suffix


def c_ver_ifdef(versions: set[str]) -> str:
    """Return a C preprocessor condition that ORs together the
    CONFIG_9P_ENABLE_* macro of each version, in sorted order."""
    macros = ["CONFIG_9P_ENABLE_" + v.replace(".", "_") for v in sorted(versions)]
    return " || ".join(macros)


def c_ver_cond(versions: set[str]) -> str:
    """Return a C expression that tests whether `ctx->ctx->version` is
    one of `versions`.

    A single version yields one parenthesized comparison; any other
    number yields those comparisons (in sorted order) joined with
    "||" inside an outer pair of parentheses.
    """
    if len(versions) != 1:
        alternatives = [c_ver_cond({v}) for v in sorted(versions)]
        return "( " + " || ".join(alternatives) + " )"
    (only,) = versions
    return f"(ctx->ctx->version=={c_ver_enum(only)})"


def c_typename(typ: Type) -> str:
    """Return the C type name used for `typ`.

    Raises ValueError if `typ` is not one of the known type classes.
    """
    if isinstance(typ, Primitive):
        return f"uint{typ.value*8}_t"
    if isinstance(typ, (Number, Bitfield)):
        return f"{idprefix}{typ.name}_t"
    # Message is checked before Struct because it is a Struct subclass.
    if isinstance(typ, Message):
        return f"struct {idprefix}msg_{typ.name}"
    if isinstance(typ, Struct):
        return f"struct {idprefix}{typ.name}"
    raise ValueError(f"not a type: {typ.__class__.__name__}")


_ifdef_stack: list[str | None] = []


def ifdef_push(n: int, _newval: str) -> str:
    # Grow the stack as needed
    global _ifdef_stack
    while len(_ifdef_stack) < n:
        _ifdef_stack += [None]

    # Set some variables
    parentval: str | None = None
    for x in _ifdef_stack[:-1]:
        if x is not None:
            parentval = x
    oldval = _ifdef_stack[-1]
    newval: str | None = _newval
    if newval == parentval:
        newval = None

    # Put newval on the stack.
    _ifdef_stack[-1] = newval

    # Build output.
    ret = ""
    if newval != oldval:
        if oldval is not None:
            ret += f"#endif /* {oldval} */\n"
        if newval is not None:
            ret += f"#if {newval}\n"
    return ret


def ifdef_pop(n: int) -> str:
    global _ifdef_stack
    ret = ""
    while len(_ifdef_stack) > n:
        if _ifdef_stack[-1] is not None:
            ret += f"#endif /* {_ifdef_stack[-1]} */\n"
        _ifdef_stack = _ifdef_stack[:-1]
    return ret


def gen_h(versions: set[str], typs: list[Type]) -> str:
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#ifndef _LIB9P_9P_H_
	#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
#endif

#include <stdint.h> /* for uint{{n}}_t types */
"""

    ret += f"""
/* config *********************************************************************/

#include "config.h"
"""
    for ver in sorted(versions):
        ret += "\n"
        ret += f"#ifndef {c_ver_ifdef({ver})}\n"
        ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n"
        ret += "#endif\n"

    ret += f"""
/* versions *******************************************************************/

enum {idprefix}version {{
"""
    fullversions = ["unknown = 0", *sorted(versions)]
    verwidth = max(len(v) for v in fullversions)
    for ver in fullversions:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f"\t{c_ver_enum(ver)},"
        ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
    ret += ifdef_pop(0)
    ret += f"\t{c_ver_enum('NUM')},\n"
    ret += "};\n"
    ret += "\n"
    ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"

    ret += """
/* non-message types **********************************************************/
"""
    for typ in [typ for typ in typs if not isinstance(typ, Message)]:
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        match typ:
            case Number():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
            case Bitfield():
                ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
                names = [
                    *reversed(
                        [typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))]
                    ),
                    "",
                    *[k for k in typ.names if k not in typ.bits],
                ]
                namewidth = max(len(name) for name in names)

                ret += "\n"
                for name in names:
                    if name == "":
                        ret += "\n"
                    elif name.startswith(" "):
                        ret += ifdef_push(2, c_ver_ifdef(typ.in_versions))
                        sp = " " * (
                            len("#  define ")
                            + len(idprefix)
                            + len(typ.name)
                            + 1
                            + namewidth
                            + 2
                            - len("/* unused")
                        )
                        ret += f"/* unused{sp}(({c_typename(typ)})(1<<{name[1:]})) */\n"
                    else:
                        ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions))
                        if name.startswith("_"):
                            c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}"
                        else:
                            c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}"
                        sp1 = "  " if _ifdef_stack[-1] else ""
                        sp2 = " " if _ifdef_stack[-1] else "   "
                        sp3 = " " * (2 + namewidth - len(name))
                        ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n"
                ret += ifdef_pop(1)
            case Struct():
                typewidth = max(len(c_typename(m.typ)) for m in typ.members)

                ret += c_typename(typ) + " {\n"
                for member in typ.members:
                    if member.val:
                        continue
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    c_type = c_typename(member.typ)
                    if (typ.name in ["d", "s"]) and member.cnt:  # SPECIAL
                        c_type = "char"
                    ret += f"\t{c_type.ljust(typewidth)}  {'*' if member.cnt else ' '}{member.name};\n"
                ret += ifdef_pop(1)
                ret += "};\n"
    ret += ifdef_pop(0)

    ret += """
/* messages *******************************************************************/

"""
    ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n"
    namewidth = max(len(msg.name) for msg in typs if isinstance(msg, Message))
    for msg in [msg for msg in typs if isinstance(msg, Message)]:
        ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
        ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n"
    ret += ifdef_pop(0)
    ret += "};\n"

    for msg in [msg for msg in typs if isinstance(msg, Message)]:
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
        ret += c_typename(msg) + " {"
        if not msg.members:
            ret += "};\n"
            continue
        ret += "\n"

        typewidth = max(len(c_typename(m.typ)) for m in msg.members)

        for member in msg.members:
            if member.val:
                continue
            ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
            ret += f"\t{c_typename(member.typ).ljust(typewidth)}  {'*' if member.cnt else ' '}{member.name};\n"
        ret += ifdef_pop(1)
        ret += "};\n"
    ret += ifdef_pop(0)

    return ret


def c_expr(expr: Expr) -> str:
    """Render an Expr as C source text, tokens separated by spaces.

    Operators and literals are emitted as-is.  The symbol "end"
    becomes "ctx->net_offset"; any other symbol has its first
    character dropped and is wrapped as "_NAME_offset".
    """
    pieces: list[str] = []
    for tok in expr.tokens:
        if isinstance(tok, ExprOp):
            pieces.append(tok.op)
        elif isinstance(tok, ExprLit):
            pieces.append(str(tok.val))
        elif isinstance(tok, ExprSym):
            if tok.name == "end":
                pieces.append("ctx->net_offset")
            else:
                pieces.append(f"_{tok.name[1:]}_offset")
    return " ".join(pieces)


def gen_c(versions: set[str], typs: list[Type]) -> str:
    global _ifdef_stack
    _ifdef_stack = []

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#include <stdbool.h>
#include <stddef.h>   /* for size_t */
#include <inttypes.h> /* for PRI* macros */
#include <string.h>   /* for memset() */

#include <libmisc/assert.h>

#include <lib9p/9p.h>

#include "internal.h"
"""

    def used(arg: str) -> str:
        return arg

    def unused(arg: str) -> str:
        return f"UNUSED({arg})"

    # strings ##################################################################
    ret += f"""
/* strings ********************************************************************/

static const char *version_strs[{c_ver_enum('NUM')}] = {{
"""
    for ver in ["unknown", *sorted(versions)]:
        if ver in versions:
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
    ret += ifdef_pop(0)
    ret += "};\n"
    ret += f"""
const char *{idprefix}version_str(enum {idprefix}version ver) {{
    assert(0 <= ver && ver < {c_ver_enum('NUM')});
    return version_strs[ver];
}}
"""

    # validate_* ###############################################################
    ret += """
/* validate_* *****************************************************************/

ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
	if (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
		/* If needed-net-size overflowed uint32_t, then
		 * there's no way that actual-net-size will live up to
		 * that.  */
		return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
	if (ctx->net_offset > ctx->net_size)
		return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
	return false;
}

ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
	if (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
		/* If needed-host-size overflowed size_t, then there's
		 * no way that actual-net-size will live up to
		 * that.  */
		return lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
	return false;
}

ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
                                         size_t cnt,
                                         _validate_fn_t item_fn, size_t item_host_size) {
	for (size_t i = 0; i < cnt; i++)
		if (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
			return true;
	return false;
}

#define validate_1(ctx) _validate_size_net(ctx, 1)
#define validate_2(ctx) _validate_size_net(ctx, 2)
#define validate_4(ctx) _validate_size_net(ctx, 4)
#define validate_8(ctx) _validate_size_net(ctx, 8)
"""
    for typ in typs:
        inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))

        if isinstance(typ, Bitfield):
            ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
            verwidth = max(len(ver) for ver in versions)
            for ver in sorted(versions):
                ret += ifdef_push(2, c_ver_ifdef({ver}))
                ret += (
                    f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
                    + "".join(
                        "1" if typ.bit_is_valid(bitname, ver) else "0"
                        for bitname in reversed(typ.bits)
                    )
                    + ",\n"
                )
            ret += ifdef_pop(1)
            ret += "};\n"

        ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"

        if typ.name == "d":  # SPECIAL
            # Optimize... maybe the compiler could figure out to do
            # this, but let's make it obvious.
            ret += "\tuint32_t base_offset = ctx->net_offset;\n"
            ret += "\tif (validate_4(ctx))\n"
            ret += "\t\treturn true;\n"
            ret += "\tuint32_t len = decode_u32le(&ctx->net_bytes[base_offset]);\n"
            ret += "\treturn _validate_size_net(ctx, len) || _validate_size_host(ctx, len);\n"
            ret += "}\n"
            continue
        if typ.name == "s":  # SPECIAL
            # Add an extra nul-byte on the host, and validate UTF-8
            # (also, similar optimization to "d").
            ret += "\tuint32_t base_offset = ctx->net_offset;\n"
            ret += "\tif (validate_2(ctx))\n"
            ret += "\t\treturn true;\n"
            ret += "\tuint16_t len = decode_u16le(&ctx->net_bytes[base_offset]);\n"
            ret += "\tif (_validate_size_net(ctx, len) || _validate_size_host(ctx, ((size_t)len)+1))\n"
            ret += "\t\treturn true;\n"
            ret += "\tif (!is_valid_utf8_without_nul(&ctx->net_bytes[base_offset+2], len))\n"
            ret += '\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
            ret += "\treturn false;\n"
            ret += "}\n"
            continue

        match typ:
            case Number():
                ret += f"\treturn validate_{typ.prim.name}(ctx);\n"
            case Bitfield():
                ret += f"\t if (validate_{typ.static_size}(ctx))\n"
                ret += "\t\treturn true;\n"
                ret += (
                    f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
                )
                ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
                ret += f"\tif (val & ~mask)\n"
                ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
                ret += "\treturn false;\n"
            case Struct():  # and Message()
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare value variables
                for member in typ.members:
                    if member.max or member.val:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t{c_typename(member.typ)} {member.name};\n"
                ret += ifdef_pop(1)

                # Pass 2 - declare offset variables
                mark_offset: set[str] = set()
                for member in typ.members:
                    for tok in [*member.max.tokens, *member.val.tokens]:
                        if isinstance(tok, ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 3 - main pass
                ret += "\treturn false\n"
                prev_size: int | None = None
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += f"\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.cnt is not None:
                        assert prev_size
                        ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
                    else:
                        if member.max or member.val:
                            ret += "("
                        if member.name in mark_offset:
                            ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                        ret += f"validate_{member.typ.name}(ctx)"
                        if member.name in mark_offset:
                            ret += "; })"
                        if member.max or member.val:
                            bytes = member.static_size
                            assert bytes
                            bits = bytes * 8
                            ret += f" || ({{ {member.name} = decode_u{bits}le(&ctx->net_bytes[ctx->net_offset-{bytes}]); false; }}))"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"
                    prev_size = member.static_size

                # Pass 4 - validate ,max= and ,val= constraints
                for member in typ.members:
                    if member.max:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
                    if member.val:
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
                        ret += f'\t          lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # unmarshal_* ##############################################################
    ret += """
/* unmarshal_* ****************************************************************/

ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
	*out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 1;
}

ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
	*out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 2;
}

ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
	*out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 4;
}

ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
	*out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 8;
}
"""
    for typ in typs:
        inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
        match typ:
            case Number():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case Bitfield():
                ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
            case Struct():
                ret += "\tmemset(out, 0, sizeof(*out));\n"

                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    if member.val:
                        ret += f"\tctx->net_offset += {member.static_size};\n"
                        continue
                    ret += "\t"

                    prefix = "\t"
                    if member.in_versions != typ.in_versions:
                        ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
                        prefix = "\t\t"
                    if member.cnt:
                        if member.in_versions != typ.in_versions:
                            ret += "{\n"
                            ret += prefix
                        ret += f"out->{member.name} = ctx->extra;\n"
                        ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
                        ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
                        if typ.name in ["d", "s"]:  # SPECIAL
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
                        else:
                            ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
                        if member.in_versions != typ.in_versions:
                            ret += "\t}\n"
                    else:
                        ret += (
                            f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
                        )
                if typ.name == "s":  # SPECIAL
                    ret += "\tctx->extra++;\n"
                    ret += "\tout->utf8[out->len] = '\\0';\n"
        ret += ifdef_pop(1)
        ret += "}\n"
    ret += ifdef_pop(0)

    # marshal_* ################################################################
    ret += """
/* marshal_* ******************************************************************/

ALWAYS_INLINE static bool _marshal_too_large(struct _marshal_ctx *ctx) {
	lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
		(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
		ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
		ctx->ctx->max_msg_size);
	return true;
}

ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
	if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
		return _marshal_too_large(ctx);
	ctx->net_bytes[ctx->net_offset] = *val;
	ctx->net_offset += 1;
	return false;
}

ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
	if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
		return _marshal_too_large(ctx);
	encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 2;
	return false;
}

ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
	if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
		return true;
	encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 4;
	return false;
}

ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
	if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
		return true;
	encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
	ctx->net_offset += 8;
	return false;
}
"""
    for typ in typs:
        inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
        argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
        ret += "\n"
        ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
        ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
        match typ:
            case Number():
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
            case Bitfield():
                ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
                ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
            case Struct():
                if len(typ.members) == 0:
                    ret += "\treturn false;\n"
                    ret += "}\n"
                    continue

                # Pass 1 - declare offset variables
                mark_offset = set()
                for member in typ.members:
                    if member.val:
                        if member.name not in mark_offset:
                            ret += f"\tuint32_t _{member.name}_offset;\n"
                        mark_offset.add(member.name)
                    for tok in member.val.tokens:
                        if isinstance(tok, ExprSym) and tok.name.startswith("&"):
                            if tok.name[1:] not in mark_offset:
                                ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
                            mark_offset.add(tok.name[1:])

                # Pass 2 - main pass
                ret += "\treturn false\n"
                for member in typ.members:
                    ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                    ret += "\t    || "
                    if member.in_versions != typ.in_versions:
                        ret += "( " + c_ver_cond(member.in_versions) + " && "
                    if member.name in mark_offset:
                        ret += f"({{ _{member.name}_offset = ctx->net_offset; "
                    if member.cnt:
                        ret += "({ bool err = false;\n"
                        ret += f"\t          for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n"
                        ret += "\t          \terr = "
                        if typ.name in ["d", "s"]:  # SPECIAL
                            # Special-case is that we cast from `char` to `uint8_t`.
                            ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
                        else:
                            ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n"
                        ret += f"\t          err; }})"
                    elif member.val:
                        # Just increment net_offset, don't actually marshal anything (yet).
                        assert member.static_size
                        ret += (
                            f"({{ ctx->net_offset += {member.static_size}; false; }})"
                        )
                    else:
                        ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
                    if member.name in mark_offset:
                        ret += "; })"
                    if member.in_versions != typ.in_versions:
                        ret += " )"
                    ret += "\n"

                # Pass 3 - marshal ,val= members
                for member in typ.members:
                    if member.val:
                        assert member.static_size
                        ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
                        ret += f"\t    || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n"

                ret += ifdef_pop(1)
                ret += "\t    ;\n"
        ret += "}\n"
    ret += ifdef_pop(0)

    # tables / exports #########################################################
    ret += f"""
/* tables / exports ***********************************************************/

#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{                          \\
		.name      = #typ,                               \\
		.basesize  = sizeof(struct {idprefix}msg_##typ),     \\
		.validate  = validate_##typ,                     \\
		.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,   \\
		.marshal   = (_marshal_fn_t)marshal_##typ,       \\
	}}
#define _NONMSG(num) [num] = {{                                   \\
		.name      = #num,                               \\
	}}

struct _table_version _{idprefix}versions[{c_ver_enum('NUM')}] = {{
"""
    id2typ: dict[int, Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, Message)]:
        id2typ[msg.msgid] = msg

    for ver in ["unknown", *sorted(versions)]:
        if ver != "unknown":
            ret += ifdef_push(1, c_ver_ifdef({ver}))
        ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n"

        for n in range(0, 0x100):
            xmsg: Message | None = id2typ.get(n, None)
            if xmsg:
                if ver == "unknown":  # SPECIAL
                    if xmsg.name not in ["Tversion", "Rversion", "Rerror"]:
                        xmsg = None
                else:
                    if ver not in xmsg.in_versions:
                        xmsg = None
            if xmsg:
                ret += f"\t\t_MSG({xmsg.name}),\n"
            else:
                ret += "\t\t_NONMSG(0x{:02X}),\n".format(n)
        ret += "\t}},\n"
    ret += ifdef_pop(0)
    ret += "};\n"

    ret += f"""
FLATTEN bool _{idprefix}validate_stat(struct _validate_ctx *ctx) {{
	return validate_stat(ctx);
}}
FLATTEN void _{idprefix}unmarshal_stat(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{
	unmarshal_stat(ctx, out);
}}
FLATTEN bool _{idprefix}marshal_stat(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{
    return marshal_stat(ctx, val);
}}
"""

    ############################################################################
    return ret


################################################################################


class Parser:
    """Parse ``.9p`` files, remembering the result for each path.

    Each file is handed to the module-level ``parse_file()`` at most
    once per ``Parser`` instance; ``all()`` then merges everything that
    has been parsed so far.

    """

    # Normalized filename -> the (version, types) pair that the
    # module-level parse_file() returned for that file.
    cache: dict[str, tuple[str, list[Type]]]

    def __init__(self) -> None:
        # The dict is created per-instance (rather than once on the
        # class) so that independent Parser objects do not see each
        # other's files in all().
        self.cache = {}

    def parse_file(self, filename: str) -> tuple[str, list[Type]]:
        """Return the ``(version, types)`` pair for ``filename``.

        The path is normalized with ``os.path.normpath`` before being
        used as the cache key, so different spellings of the same
        lexical path share one cache entry.  The file is only passed
        to the module-level ``parse_file()`` the first time it is
        requested.

        """
        filename = os.path.normpath(filename)
        if filename not in self.cache:

            def get_include(other_filename: str) -> tuple[str, list[Type]]:
                # Callback given to the module-level parse_file();
                # presumably it is invoked for files that `filename`
                # includes.  Joining onto "<filename>/.." makes
                # `other_filename` relative to the directory of the
                # including file once normpath (in the recursive call)
                # collapses the "..".
                return self.parse_file(os.path.join(filename, "..", other_filename))

            self.cache[filename] = parse_file(filename, get_include)
        return self.cache[filename]

    def all(self) -> tuple[set[str], list[Type]]:
        """Merge every file parsed so far.

        Returns a tuple of (the set of version strings, the list of
        types).  A type name that shows up in several files is listed
        once, using the first object seen.

        Raises ValueError if two cached files report the same version,
        or if two types that compare unequal share a name.

        """
        ret_versions: set[str] = set()
        ret_typs: dict[str, Type] = {}
        for version, typs in self.cache.values():
            if version in ret_versions:
                raise ValueError(f"duplicate protocol version {version!r}")
            ret_versions.add(version)
            for typ in typs:
                prev = ret_typs.get(typ.name)
                if prev is None:
                    ret_typs[typ.name] = typ
                elif typ != prev:
                    raise ValueError(f"duplicate type name {typ.name!r}")
        return ret_versions, list(ret_typs.values())


if __name__ == "__main__":
    # Script entry point: parse every .9p file named on the command
    # line, then write the generated C header and source file.
    import sys

    input_names = sys.argv[1:]
    if not input_names:
        raise ValueError("requires at least 1 .9p filename")

    parser = Parser()
    for input_name in input_names:
        parser.parse_file(input_name)
    versions, typs = parser.all()

    # Output paths are built relative to the directory part of
    # sys.argv[0]: "<argv0>/.." collapsed lexically by normpath.
    outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
    header_path = os.path.join(outdir, "include/lib9p/9p.generated.h")
    source_path = os.path.join(outdir, "9p.generated.c")

    with open(header_path, "w") as header_fh:
        header_fh.write(gen_h(versions, typs))
    with open(source_path, "w") as source_fh:
        source_fh.write(gen_c(versions, typs))