summaryrefslogtreecommitdiff
path: root/lib9p/idl.gen
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-09 17:01:06 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-10-09 17:01:06 -0600
commitbb5afed7a0eeaf361be1e29b3a3ab8ace2865b39 (patch)
treed311d7c8a658079ee5280d5d057203a5d501091c /lib9p/idl.gen
parentcb8893dd08b7b359f45ef225acd2f6e103d38bba (diff)
lib9p: finish refactor
Diffstat (limited to 'lib9p/idl.gen')
-rwxr-xr-xlib9p/idl.gen589
1 files changed, 297 insertions, 292 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen
index 1f5e48c..8e7639a 100755
--- a/lib9p/idl.gen
+++ b/lib9p/idl.gen
@@ -17,6 +17,7 @@ from typing import Callable, Final, Literal, TypeAlias, TypeVar, cast
# Types ########################################################################
+
class Primitive(enum.Enum):
u8 = 1
u16 = 2
@@ -128,7 +129,7 @@ class StructMember:
# from left-to-right when parsing
cnt: str | None = None
name: str
- typ: 'Type'
+ typ: "Type"
max: Expr
val: Expr
@@ -174,7 +175,7 @@ class Message(Struct):
Type: TypeAlias = Primitive | Number | Bitfield | Struct | Message
-#type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
+# type Type = Primitive | Number | Bitfield | Struct | Message # Change to this once we have Python 3.13
T = TypeVar("T", Number, Bitfield, Struct, Message)
# Parse *.9p ###################################################################
@@ -234,6 +235,8 @@ def parse_expr(expr: str) -> Expr:
ret = Expr()
for tok in re.split("([-+])", expr):
if tok == "-" or tok == "+":
+ # I, for the life of me, do not understand why I need this
+ # cast() to keep mypy happy.
ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))]
elif re.fullmatch("[0-9]+", tok):
ret.tokens += [ExprLit(int(tok))]
@@ -451,8 +454,10 @@ def parse_file(
# Generate C ###################################################################
+idprefix = "lib9p_"
+
-def c_ver_enum(idprefix: str, ver: str) -> str:
+def c_ver_enum(ver: str) -> str:
return f"{idprefix.upper()}VER_{ver.replace('.', '_')}"
@@ -462,15 +467,13 @@ def c_ver_ifdef(versions: set[str]) -> str:
)
-def c_ver_cond(idprefix: str, versions: set[str]) -> str:
+def c_ver_cond(versions: set[str]) -> str:
if len(versions) == 1:
- return f"(ctx->ctx->version=={c_ver_enum(idprefix, next(v for v in versions))})"
- return (
- "( " + (" || ".join(c_ver_cond(idprefix, {v}) for v in sorted(versions))) + " )"
- )
+ return f"(ctx->ctx->version=={c_ver_enum(next(v for v in versions))})"
+ return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )"
-def c_typename(idprefix: str, typ: Type) -> str:
+def c_typename(typ: Type) -> str:
match typ:
case Primitive():
return f"uint{typ.value*8}_t"
@@ -517,6 +520,7 @@ def ifdef_push(n: int, _newval: str) -> str:
ret += f"#if {newval}\n"
return ret
+
def ifdef_pop(n: int) -> str:
global _ifdef_stack
ret = ""
@@ -527,7 +531,7 @@ def ifdef_pop(n: int) -> str:
return ret
-def gen_h(idprefix: str, versions: set[str], typs: list[Type]) -> str:
+def gen_h(versions: set[str], typs: list[Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -550,10 +554,10 @@ enum {idprefix}version {{
for ver in fullversions:
if ver in versions:
ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f"\t{c_ver_enum(idprefix, ver)},"
+ ret += f"\t{c_ver_enum(ver)},"
ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
ret += ifdef_pop(0)
- ret += f"\t{c_ver_enum(idprefix, 'NUM')},\n"
+ ret += f"\t{c_ver_enum('NUM')},\n"
ret += "};\n"
ret += "\n"
ret += f"const char *{idprefix}version_str(enum {idprefix}version);\n"
@@ -566,9 +570,9 @@ enum {idprefix}version {{
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
match typ:
case Number():
- ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n"
+ ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
case Bitfield():
- ret += f"typedef {c_typename(idprefix, typ.prim)} {c_typename(idprefix, typ)};\n"
+ ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n"
names = [
*reversed(
[typ.bits[n] or f" {n}" for n in range(0, len(typ.bits))]
@@ -584,28 +588,36 @@ enum {idprefix}version {{
ret += "\n"
elif name.startswith(" "):
ret += ifdef_push(2, c_ver_ifdef(typ.in_versions))
- sp = ' '*(len('# define ')+len(idprefix)+len(typ.name)+1+namewidth+2-len("/* unused"))
- ret += f"/* unused{sp}(({c_typename(idprefix, typ)})(1<<{name[1:]})) */\n"
+ sp = " " * (
+ len("# define ")
+ + len(idprefix)
+ + len(typ.name)
+ + 1
+ + namewidth
+ + 2
+ - len("/* unused")
+ )
+ ret += f"/* unused{sp}(({c_typename(typ)})(1<<{name[1:]})) */\n"
else:
ret += ifdef_push(2, c_ver_ifdef(typ.names[name].in_versions))
if name.startswith("_"):
c_name = f"_{idprefix.upper()}{typ.name.upper()}_{name[1:]}"
else:
c_name = f"{idprefix.upper()}{typ.name.upper()}_{name}"
- sp1 = ' ' if _ifdef_stack[-1] else ''
- sp2 = ' ' if _ifdef_stack[-1] else ' '
- sp3 = ' '*(2+namewidth-len(name))
- ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(idprefix, typ)})({typ.names[name].val}))\n"
+ sp1 = " " if _ifdef_stack[-1] else ""
+ sp2 = " " if _ifdef_stack[-1] else " "
+ sp3 = " " * (2 + namewidth - len(name))
+ ret += f"#{sp1}define{sp2}{c_name}{sp3}(({c_typename(typ)})({typ.names[name].val}))\n"
ret += ifdef_pop(1)
case Struct():
- typewidth = max(len(c_typename(idprefix, m.typ)) for m in typ.members)
+ typewidth = max(len(c_typename(m.typ)) for m in typ.members)
- ret += c_typename(idprefix, typ) + " {\n"
+ ret += c_typename(typ) + " {\n"
for member in typ.members:
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- c_type = c_typename(idprefix, member.typ)
+ c_type = c_typename(member.typ)
if (typ.name in ["d", "s"]) and member.cnt: # SPECIAL
c_type = "char"
ret += f"\t{c_type.ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
@@ -630,19 +642,19 @@ enum {idprefix}version {{
for msg in [msg for msg in typs if isinstance(msg, Message)]:
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(msg.in_versions))
- ret += c_typename(idprefix, msg) + " {"
+ ret += c_typename(msg) + " {"
if not msg.members:
ret += "};\n"
continue
ret += "\n"
- typewidth = max(len(c_typename(idprefix, m.typ)) for m in msg.members)
+ typewidth = max(len(c_typename(m.typ)) for m in msg.members)
for member in msg.members:
if member.val:
continue
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(idprefix, member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
+ ret += f"\t{c_typename(member.typ).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n"
ret += ifdef_pop(1)
ret += "};\n"
ret += ifdef_pop(0)
@@ -665,7 +677,7 @@ def c_expr(expr: Expr) -> str:
return " ".join(ret)
-def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str:
+def gen_c(versions: set[str], typs: list[Type]) -> str:
global _ifdef_stack
_ifdef_stack = []
@@ -692,17 +704,17 @@ def gen_c(idprefix: str, versions: set[str], typs: list[Type]) -> str:
ret += f"""
/* strings ********************************************************************/
-static const char *version_strs[{c_ver_enum(idprefix, 'NUM')}] = {{
+static const char *version_strs[{c_ver_enum('NUM')}] = {{
"""
for ver in ["unknown", *sorted(versions)]:
if ver in versions:
ret += ifdef_push(1, c_ver_ifdef({ver}))
- ret += f'\t[{c_ver_enum(idprefix, ver)}] = "{ver}",\n'
+ ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n'
ret += ifdef_pop(0)
ret += "};\n"
ret += f"""
const char *{idprefix}version_str(enum {idprefix}version ver) {{
- assert(0 <= ver && ver < {c_ver_enum(idprefix, 'NUM')});
+ assert(0 <= ver && ver < {c_ver_enum('NUM')});
return version_strs[ver];
}}
@@ -764,6 +776,23 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
ret += "\n"
ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+
+ if isinstance(typ, Bitfield):
+ ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n"
+ verwidth = max(len(ver) for ver in versions)
+ for ver in sorted(versions):
+ ret += ifdef_push(2, c_ver_ifdef({ver}))
+ ret += (
+ f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
+ + "".join(
+ "1" if typ.bit_is_valid(bitname, ver) else "0"
+ for bitname in reversed(typ.bits)
+ )
+ + ",\n"
+ )
+ ret += ifdef_pop(1)
+ ret += "};\n"
+
ret += f"static {inline} bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n"
if typ.name == "d": # SPECIAL
@@ -797,28 +826,12 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
case Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
- ret += f"\tstatic const {c_typename(idprefix, typ)} masks[{c_ver_enum(idprefix, 'NUM')}] = {{\n"
- verwidth = max(len(ver) for ver in versions)
- for ver in sorted(versions):
- ret += ifdef_push(2, c_ver_ifdef({ver}))
- ret += (
- f"\t\t[{c_ver_enum(idprefix, ver)}]{' '*(verwidth-len(ver))} = 0b"
- + "".join(
- "1" if typ.bit_is_valid(bitname, ver) else "0"
- for bitname in reversed(typ.bits)
- )
- + ",\n"
- )
- ret += ifdef_pop(1)
- ret += "\t};\n"
ret += (
- f"\t{c_typename(idprefix, typ)} mask = masks[ctx->ctx->version];\n"
+ f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n"
)
- ret += f"\t{c_typename(idprefix, typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
+ ret += f"\t{c_typename(typ)} val = decode_u{typ.static_size*8}le(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
ret += f"\tif (val & ~mask)\n"
- ret += "\t\treturn lib9p_errorf(ctx->ctx,\n"
- ret += f'\t\t LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8},\n'
- ret += "\t\t val & ~mask);\n"
+ ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
case Struct(): # and Message()
if len(typ.members) == 0:
@@ -826,14 +839,14 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += "}\n"
continue
- # Pass 1
+ # Pass 1 - declare value variables
for member in typ.members:
if member.max or member.val:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\t{c_typename(idprefix, member.typ)} {member.name};\n"
+ ret += f"\t{c_typename(member.typ)} {member.name};\n"
ret += ifdef_pop(1)
- # Pass 2
+ # Pass 2 - declare offset variables
mark_offset: set[str] = set()
for member in typ.members:
for tok in [*member.max.tokens, *member.val.tokens]:
@@ -842,17 +855,17 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += f"\tuint32_t _{tok.name[1:]}_offset;\n"
mark_offset.add(tok.name[1:])
- # Pass 3
+ # Pass 3 - main pass
ret += "\treturn false\n"
prev_size: int | None = None
for member in typ.members:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
ret += f"\t || "
if member.in_versions != typ.in_versions:
- ret += "( " + c_ver_cond(idprefix, member.in_versions) + " && "
+ ret += "( " + c_ver_cond(member.in_versions) + " && "
if member.cnt is not None:
assert prev_size
- ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(idprefix, member.typ)}))"
+ ret += f"_validate_list(ctx, decode_u{prev_size*8}le(&ctx->net_bytes[ctx->net_offset-{prev_size}]), validate_{member.typ.name}, sizeof({c_typename(member.typ)}))"
else:
if member.max or member.val:
ret += "("
@@ -871,252 +884,244 @@ static ALWAYS_INLINE bool _validate_list(struct _validate_ctx *ctx,
ret += "\n"
prev_size = member.static_size
- # Pass 4
+ # Pass 4 - validate ,max= and ,val= constraints
for member in typ.members:
if member.max:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\n\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
- ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
+ ret += f"\t || ({{ uint32_t max = {c_expr(member.max)}; (((uint32_t){member.name}) > max) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu32" > %"PRIu32")", {member.name}, max); }})\n'
if member.val:
ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
- ret += f"\n\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
- ret += f'\n\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'
+ ret += f"\t || ({{ uint32_t exp = {c_expr(member.val)}; (((uint32_t){member.name}) != exp) &&\n"
+ ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu32" != correct:%"PRIu32")", (uint32_t){member.name}, exp); }})\n'
ret += ifdef_pop(1)
ret += "\t ;\n"
ret += "}\n"
ret += ifdef_pop(0)
- # # unmarshal_* ##############################################################
- # ret += """
- # /* unmarshal_* ****************************************************************/
-
- # static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
- # *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 1;
- # }
-
- # static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
- # *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 2;
- # }
-
- # static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
- # *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 4;
- # }
-
- # static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
- # *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 8;
- # }
- # """
- # for typ in typs:
- # inline = (
- # " FLATTEN"
- # if (isinstance(typ, Struct) and typ.msgid is not None)
- # else " ALWAYS_INLINE"
- # )
- # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
- # ret += "\n"
- # ret += f"static{inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *out) {{\n"
- # match typ:
- # case Bitfield():
- # ret += f"\tunmarshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)out);\n"
- # case Struct():
- # ret += "\tmemset(out, 0, sizeof(*out));\n"
-
- # if typ.members:
- # struct_versions = typ.members[0].ver
- # for member in typ.members:
- # if member.valexpr:
- # ret += f"\tctx->net_offset += {member.static_size};\n"
- # continue
- # ret += "\t"
- # prefix = "\t"
- # if member.ver != struct_versions:
- # ret += "if ( " + c_ver_cond(idprefix, member.ver) + " ) "
- # prefix = "\t\t"
- # if member.cnt:
- # if member.ver != struct_versions:
- # ret += f"{{\n{prefix}"
- # ret += f"out->{member.name} = ctx->extra;\n"
- # ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
- # ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
- # if typ.name in ["d", "s"]: # SPECIAL
- # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
- # else:
- # ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
- # if member.ver != struct_versions:
- # ret += "\t}\n"
- # else:
- # ret += (
- # f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
- # )
- # if typ.name == "s": # SPECIAL
- # ret += "\tctx->extra++;\n"
- # ret += "\tout->utf8[out->len] = '\\0';\n"
- # ret += "}\n"
-
- # # marshal_* ################################################################
- # ret += """
- # /* marshal_* ******************************************************************/
-
- # static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
- # lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
- # (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
- # ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
- # ctx->ctx->max_msg_size);
- # return true;
- # }
-
- # static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
- # if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
- # return _marshal_too_large(ctx);
- # ctx->net_bytes[ctx->net_offset] = *val;
- # ctx->net_offset += 1;
- # return false;
- # }
-
- # static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
- # if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
- # return _marshal_too_large(ctx);
- # encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 2;
- # return false;
- # }
-
- # static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
- # if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
- # return true;
- # encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 4;
- # return false;
- # }
-
- # static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
- # if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
- # return true;
- # encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
- # ctx->net_offset += 8;
- # return false;
- # }
- # """
- # for typ in typs:
- # inline = (
- # " FLATTEN"
- # if (isinstance(typ, Struct) and typ.msgid is not None)
- # else " ALWAYS_INLINE"
- # )
- # argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
- # ret += "\n"
- # ret += f"static{inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(idprefix, typ)} *{argfn('val')}) {{"
- # match typ:
- # case Bitfield():
- # ret += "\n"
- # ret += f"\treturn marshal_{typ.static_size}(ctx, (uint{typ.static_size*8}_t *)val);\n"
- # case Struct():
- # if len(typ.members) == 0:
- # ret += "\n\treturn false;\n"
- # ret += "}\n"
- # continue
-
- # mark_offset = set()
- # for member in typ.members:
- # if member.valexpr:
- # if member.name not in mark_offset:
- # ret += f"\n\tuint32_t _{member.name}_offset;"
- # mark_offset.add(member.name)
- # for tok in member.valexpr:
- # if (
- # isinstance(tok, ExprVal)
- # and tok.name.startswith("&")
- # and tok.name[1:] not in mark_offset
- # ):
- # ret += f"\n\tuint32_t _{tok.name[1:]}_offset;"
- # mark_offset.add(tok.name[1:])
-
- # prefix0 = "\treturn "
- # prefix1 = "\t || "
- # prefix2 = "\t "
-
- # struct_versions = typ.members[0].ver
- # prefix = prefix0
- # for member in typ.members:
- # ret += f"\n{prefix}"
- # if member.ver != struct_versions:
- # ret += "( " + c_ver_cond(idprefix, member.ver) + " && "
- # if member.name in mark_offset:
- # ret += f"({{ _{member.name}_offset = ctx->net_offset; "
- # if member.cnt:
- # ret += "({"
- # ret += f"\n{prefix2}\tbool err = false;"
- # ret += f"\n{prefix2}\tfor (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)"
- # if typ.name in ["d", "s"]: # SPECIAL
- # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);"
- # else:
- # ret += f"\n{prefix2}\t\terr = marshal_{member.typ.name}(ctx, &val->{member.name}[i]);"
- # ret += f"\n{prefix2}\terr;"
- # ret += f"\n{prefix2}}})"
- # elif member.valexpr:
- # assert member.static_size
- # ret += (
- # f"({{ ctx->net_offset += {member.static_size}; false; }})"
- # )
- # else:
- # ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
- # if member.name in mark_offset:
- # ret += "; })"
- # if member.ver != struct_versions:
- # ret += " )"
- # prefix = prefix1
-
- # for member in typ.members:
- # if member.valexpr:
- # assert member.static_size
- # ret += f"\n{prefix}"
- # ret += f"({{ encode_u{member.static_size*8}le("
- # for tok in member.valexpr:
- # match tok:
- # case ExprOp():
- # ret += f" {tok.op}"
- # case ExprVal(name="end"):
- # ret += " ctx->net_offset"
- # case ExprVal():
- # ret += f" _{tok.name[1:]}_offset"
- # ret += f", &ctx->net_bytes[_{member.name}_offset]); false; }})"
-
- # ret += ";\n"
- # ret += "}\n"
-
- # # vtables ##################################################################
- # ret += f"""
- # /* vtables ********************************************************************/
-
- # #define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
- # .basesize = sizeof(struct {idprefix}msg_##typ), \\
- # .validate = validate_##typ, \\
- # .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
- # .marshal = (_marshal_fn_t)marshal_##typ, \\
- # }}
-
- # struct _vtable_version _{idprefix}vtables[{c_ver_enum(idprefix, 'NUM')}] = {{
- # """
-
- # ret += f"\t[{c_ver_enum(idprefix, 'unknown')}] = {{ .msgs = {{\n"
- # for msg in just_structs_msg(typs):
- # if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL
- # ret += f"\t\t_MSG({msg.name}),\n"
- # ret += "\t}},\n"
-
- # for ver in sorted(versions):
- # ret += f"\t[{c_ver_enum(idprefix, ver)}] = {{ .msgs = {{\n"
- # for msg in just_structs_msg(typs):
- # if ver not in msg.msgver:
- # continue
- # ret += f"\t\t_MSG({msg.name}),\n"
- # ret += "\t}},\n"
- # ret += "};\n"
+ # unmarshal_* ##############################################################
+ ret += """
+/* unmarshal_* ****************************************************************/
+
+static ALWAYS_INLINE void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) {
+ *out = decode_u8le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 1;
+}
+
+static ALWAYS_INLINE void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) {
+ *out = decode_u16le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 2;
+}
+
+static ALWAYS_INLINE void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) {
+ *out = decode_u32le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 4;
+}
+
+static ALWAYS_INLINE void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) {
+ *out = decode_u64le(&ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 8;
+}
+"""
+ for typ in typs:
+ inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+ ret += f"static {inline} void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n"
+ match typ:
+ case Number():
+ ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ case Bitfield():
+ ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n"
+ case Struct():
+ ret += "\tmemset(out, 0, sizeof(*out));\n"
+
+ for member in typ.members:
+ ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
+ if member.val:
+ ret += f"\tctx->net_offset += {member.static_size};\n"
+ continue
+ ret += "\t"
+
+ prefix = "\t"
+ if member.in_versions != typ.in_versions:
+ ret += "if ( " + c_ver_cond(member.in_versions) + " ) "
+ prefix = "\t\t"
+ if member.cnt:
+ if member.in_versions != typ.in_versions:
+ ret += "{\n"
+ ret += prefix
+ ret += f"out->{member.name} = ctx->extra;\n"
+ ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt};\n"
+ ret += f"{prefix}for (typeof(out->{member.cnt}) i = 0; i < out->{member.cnt}; i++)\n"
+ if typ.name in ["d", "s"]: # SPECIAL
+ ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n"
+ else:
+ ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n"
+ if member.in_versions != typ.in_versions:
+ ret += "\t}\n"
+ else:
+ ret += (
+ f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n"
+ )
+ if typ.name == "s": # SPECIAL
+ ret += "\tctx->extra++;\n"
+ ret += "\tout->utf8[out->len] = '\\0';\n"
+ ret += ifdef_pop(1)
+ ret += "}\n"
+ ret += ifdef_pop(0)
+
+ # marshal_* ################################################################
+ ret += """
+/* marshal_* ******************************************************************/
+
+static ALWAYS_INLINE bool _marshal_too_large(struct _marshal_ctx *ctx) {
+ lib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")",
+ (ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message",
+ ctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"),
+ ctx->ctx->max_msg_size);
+ return true;
+}
+
+static ALWAYS_INLINE bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) {
+ if (ctx->net_offset + 1 > ctx->ctx->max_msg_size)
+ return _marshal_too_large(ctx);
+ ctx->net_bytes[ctx->net_offset] = *val;
+ ctx->net_offset += 1;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) {
+ if (ctx->net_offset + 2 > ctx->ctx->max_msg_size)
+ return _marshal_too_large(ctx);
+ encode_u16le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 2;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) {
+ if (ctx->net_offset + 4 > ctx->ctx->max_msg_size)
+ return true;
+ encode_u32le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 4;
+ return false;
+}
+
+static ALWAYS_INLINE bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) {
+ if (ctx->net_offset + 8 > ctx->ctx->max_msg_size)
+ return true;
+ encode_u64le(*val, &ctx->net_bytes[ctx->net_offset]);
+ ctx->net_offset += 8;
+ return false;
+}
+"""
+ for typ in typs:
+ inline = "FLATTEN" if isinstance(typ, Message) else "ALWAYS_INLINE"
+ argfn = unused if (isinstance(typ, Struct) and not typ.members) else used
+ ret += "\n"
+ ret += ifdef_push(1, c_ver_ifdef(typ.in_versions))
+ ret += f"static {inline} bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n"
+ match typ:
+ case Number():
+ ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n"
+ case Bitfield():
+ ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n"
+ ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n"
+ case Struct():
+ if len(typ.members) == 0:
+ ret += "\treturn false;\n"
+ ret += "}\n"
+ continue
+
+ # Pass 1 - declare offset variables
+ mark_offset = set()
+ for member in typ.members:
+ if member.val:
+ if member.name not in mark_offset:
+ ret += f"\tuint32_t _{member.name}_offset;\n"
+ mark_offset.add(member.name)
+ for tok in member.val.tokens:
+ if isinstance(tok, ExprSym) and tok.name.startswith("&"):
+ if tok.name[1:] not in mark_offset:
+ ret += f"\n\tuint32_t _{tok.name[1:]}_offset;"
+ mark_offset.add(tok.name[1:])
+
+ # Pass 2 - main pass
+ ret += "\treturn false\n"
+ for member in typ.members:
+ ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
+ ret += "\t || "
+ if member.in_versions != typ.in_versions:
+ ret += "( " + c_ver_cond(member.in_versions) + " && "
+ if member.name in mark_offset:
+ ret += f"({{ _{member.name}_offset = ctx->net_offset; "
+ if member.cnt:
+ ret += "({ bool err = false;\n"
+ ret += f"\t for (typeof(val->{member.cnt}) i = 0; i < val->{member.cnt} && !err; i++)\n"
+ ret += "\t \terr = "
+ if typ.name in ["d", "s"]: # SPECIAL
+ # Special-case is that we cast from `char` to `uint8_t`.
+ ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n"
+ else:
+ ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n"
+ ret += f"\t err; }})"
+ elif member.val:
+ # Just increment net_offset, don't actually marsha anything (yet).
+ assert member.static_size
+ ret += (
+ f"({{ ctx->net_offset += {member.static_size}; false; }})"
+ )
+ else:
+ ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})"
+ if member.name in mark_offset:
+ ret += "; })"
+ if member.in_versions != typ.in_versions:
+ ret += " )"
+ ret += "\n"
+
+ # Pass 3 - marshal ,val= members
+ for member in typ.members:
+ if member.val:
+ assert member.static_size
+ ret += ifdef_push(2, c_ver_ifdef(member.in_versions))
+ ret += f"\t || ({{ encode_u{member.static_size*8}le({c_expr(member.val)}, &ctx->net_bytes[_{member.name}_offset]); false; }})\n"
+
+ ret += ifdef_pop(1)
+ ret += "\t ;\n"
+ ret += "}\n"
+
+ # vtables ##################################################################
+ ret += f"""
+/* vtables ********************************************************************/
+
+#define _MSG(typ) [{idprefix.upper()}TYP_##typ] = {{ \\
+ .basesize = sizeof(struct {idprefix}msg_##typ), \\
+ .validate = validate_##typ, \\
+ .unmarshal = (_unmarshal_fn_t)unmarshal_##typ, \\
+ .marshal = (_marshal_fn_t)marshal_##typ, \\
+ }}
+
+struct _vtable_version _{idprefix}vtables[{c_ver_enum('NUM')}] = {{
+"""
+
+ ret += f"\t[{c_ver_enum('unknown')}] = {{ .msgs = {{\n"
+ for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ if msg.name in ["Tversion", "Rversion", "Rerror"]: # SPECIAL
+ ret += f"\t\t_MSG({msg.name}),\n"
+ ret += "\t}},\n"
+
+ for ver in sorted(versions):
+ ret += ifdef_push(1, c_ver_ifdef({ver}))
+ ret += f"\t[{c_ver_enum(ver)}] = {{ .msgs = {{\n"
+ for msg in [msg for msg in typs if isinstance(msg, Message)]:
+ if ver not in msg.in_versions:
+ continue
+ ret += f"\t\t_MSG({msg.name}),\n"
+ ret += "\t}},\n"
+ ret += ifdef_pop(0)
+ ret += "};\n"
############################################################################
return ret
@@ -1165,6 +1170,6 @@ if __name__ == "__main__":
versions, typs = parser.all()
outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh:
- fh.write(gen_h("lib9p_", versions, typs))
+ fh.write(gen_h(versions, typs))
with open(os.path.join(outdir, "9p.generated.c"), "w") as fh:
- fh.write(gen_c("lib9p_", versions, typs))
+ fh.write(gen_c(versions, typs))