1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
|
# lib9p/protogen/c_marshal.py - Generate C marshal functions
#
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
import typing
import idl
from . import c9util, cutil, idlutil
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".
# pylint: disable=unused-variable
# Only the generator entry point is exported; everything else here is a helper.
__all__ = ["gen_c_marshal"]
# get_offset_expr() ############################################################
class OffsetExpr:
static: int
cond: dict[frozenset[str], "OffsetExpr"]
rep: list[tuple[idlutil.Path, "OffsetExpr"]]
def __init__(self) -> None:
self.static = 0
self.rep = []
self.cond = {}
def add(self, other: "OffsetExpr") -> None:
self.static += other.static
self.rep += other.rep
for k, v in other.cond.items():
if k in self.cond:
self.cond[k].add(v)
else:
self.cond[k] = v
def gen_c(
self,
dsttyp: str,
dstvar: str,
root: str,
indent_depth: int,
loop_depth: int,
) -> str:
oneline: list[str] = []
multiline = ""
if self.static:
oneline.append(str(self.static))
for cnt, sub in self.rep:
if not sub.cond and not sub.rep:
if sub.static == 1:
oneline.append(cnt.c_str(root))
else:
oneline.append(f"({cnt.c_str(root)})*{sub.static}")
continue
loopvar = chr(ord("i") + loop_depth)
multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth + 1)
multiline += f"{'\t'*indent_depth}}}\n"
for vers, sub in self.cond.items():
multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
multiline += f"{'\t'*indent_depth}}}\n"
multiline += cutil.ifdef_pop(indent_depth)
ret = ""
if dsttyp:
if not oneline:
oneline.append("0")
ret += f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
elif oneline:
ret += f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
ret += multiline
return ret
# Callback that steers get_offset_expr(): it is called with each member path.
# Returning idlutil.WalkCmd.KEEP_GOING counts that member.  Any other WalkCmd
# is handed back to idlutil.walk() without counting the member; go_to_tok()
# uses ABORT to stop at a named member.
type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]
def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
    """Build an OffsetExpr for the encoded size of (a prefix of) `typ`.

    For a non-struct type this is just its static size.  For a struct, the
    members are visited with idlutil.walk().  `recurse` is consulted for
    each path, and only paths it answers KEEP_GOING for are counted.

    With go_to_end() every member counts.  With go_to_tok(name), the result
    covers the members walked before top-level member `name`, i.e. its
    offset.  This assumes the walk visits members in declaration order and
    that ABORT ends it.

    Version-conditional members end up under `cond`, keyed by version set.
    Counted (repeated) members end up under `rep`, paired with the path of
    their count member.
    """
    if not isinstance(typ, idl.Struct):
        assert typ.static_size
        ret = OffsetExpr()
        ret.static = typ.static_size
        return ret
    # One open scope of the expression being built.  Sizes are added to the
    # top entry's `expr`; its `pop` merges that expr into the entry beneath.
    class ExprStackItem(typing.NamedTuple):
        path: idlutil.Path
        expr: OffsetExpr
        pop: typing.Callable[[], None]
    expr_stack: list[ExprStackItem]
    # The root entry holds the final result and must never be popped.
    def pop_root() -> None:
        assert False
    # Merge the top (version-conditional) entry into its parent's `cond`,
    # keyed by the member's version set, combining with any existing entry.
    def pop_cond() -> None:
        nonlocal expr_stack
        key = frozenset(expr_stack[-1].path.elems[-1].in_versions)
        if key in expr_stack[-2].expr.cond:
            expr_stack[-2].expr.cond[key].add(expr_stack[-1].expr)
        else:
            expr_stack[-2].expr.cond[key] = expr_stack[-1].expr
        expr_stack = expr_stack[:-1]
    # Record the top (repeated) entry in its parent's `rep`, paired with the
    # path of the member holding the repeat count (resolved relative to the
    # same parent).
    def pop_rep() -> None:
        nonlocal expr_stack
        member_path = expr_stack[-1].path
        member = member_path.elems[-1]
        assert member.cnt
        cnt_path = member_path.parent().add(member.cnt)
        expr_stack[-2].expr.rep.append((cnt_path, expr_stack[-1].expr))
        expr_stack = expr_stack[:-1]
    # idlutil.walk() callback: push scopes for `path` and add its size.
    # Returns the walk command, plus a `pop` that unwinds whatever was pushed
    # here.  The `pop` is None when `recurse` excluded the path.
    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
        nonlocal recurse
        ret = recurse(path)
        if ret != idlutil.WalkCmd.KEEP_GOING:
            # Excluded: this path contributes nothing.
            return ret, None
        nonlocal expr_stack
        expr_stack_len = len(expr_stack)
        def pop() -> None:
            # Unwind every entry pushed since this path was entered.
            nonlocal expr_stack
            nonlocal expr_stack_len
            while len(expr_stack) > expr_stack_len:
                expr_stack[-1].pop()
        if path.elems:
            child = path.elems[-1]
            parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
            # A member in a strict subset of its parent's versions is
            # version-conditional.
            if child.in_versions < parent.in_versions:
                expr_stack.append(
                    ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_cond)
                )
            # A counted member repeats.  It is pushed after the cond entry,
            # so the repetition nests inside the version check.
            if child.cnt:
                expr_stack.append(
                    ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_rep)
                )
            # Leaf type: add its fixed size to the innermost open scope.
            if not isinstance(child.typ, idl.Struct):
                assert child.typ.static_size
                expr_stack[-1].expr.static += child.typ.static_size
        return ret, pop
    expr_stack = [
        ExprStackItem(path=idlutil.Path(typ), expr=OffsetExpr(), pop=pop_root)
    ]
    idlutil.walk(typ, handle)
    return expr_stack[0].expr
def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
    """OffsetExprRecursion that accepts every path.

    Passing it to get_offset_expr() counts every member of the type.
    """
    del path  # every member counts, so the path itself is irrelevant
    return idlutil.WalkCmd.KEEP_GOING
def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
    """Build an OffsetExprRecursion that stops at top-level member `name`.

    The returned callback answers ABORT for the path of the top-level member
    called `name`, and KEEP_GOING for every other path.  As a result,
    get_offset_expr() counts only what is walked before that member.
    """
    def stop_at_member(path: idlutil.Path) -> idlutil.WalkCmd:
        elems = path.elems
        if len(elems) != 1 or elems[0].membname != name:
            return idlutil.WalkCmd.KEEP_GOING
        return idlutil.WalkCmd.ABORT
    return stop_at_member
# Generate .c ##################################################################
def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
    """Generate the C `marshal_*` functions and the macros they use.

    For each message type in `typs`, and as a special case the "stat"
    struct, this emits
    `static bool marshal_<typname>(struct _marshal_ctx *ctx, <type> *val)`.
    That function:

      1. computes the encoded size into `needed_size`.  If it exceeds
         `ctx->ctx->max_msg_size`, it returns true; messages (not stat)
         also report a LINUX_ERANGE error via lib9p_errorf().
      2. otherwise appends each member's encoding to `ctx->net_iov` /
         `ctx->net_copied`, then returns false.

    Each function is wrapped in the `#if` for the type's protocol versions.
    The `versions` argument is not referenced in this body.

    Returns the generated C source text.
    """
    ret = """
/* marshal_* ******************************************************************/
"""
    # In all macros, the iovec being filled is net_iov[net_iov_cnt-1].
    # MARSHAL_BYTES_ZEROCOPY: point an iovec at `data` in place.  It first
    # starts a new iovec if the current one is non-empty, then moves on to a
    # fresh iovec for whatever follows.
    ret += cutil.macro(
        "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
        "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
        "\t\tctx->net_iov_cnt++;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
        "\tctx->net_iov_cnt++;\n"
    )
    # MARSHAL_BYTES: copy `len` bytes into net_copied and extend the current
    # iovec over them.  The iovec's base is set on first use.
    ret += cutil.macro(
        "#define MARSHAL_BYTES(ctx, data, len)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
        "\tctx->net_copied_size += len;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
    )
    # MARSHAL_U{8,16,32,64}LE: like MARSHAL_BYTES, but for a single integer.
    # The integer is stored directly (U8) or written by uintNNle_encode().
    ret += cutil.macro(
        "#define MARSHAL_U8LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tctx->net_copied[ctx->net_copied_size] = val;\n"
        "\tctx->net_copied_size += 1;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U16LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 2;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U32LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 4;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U64LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 8;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
    )
    # One entry per C `{` currently open in the function being generated.
    class IndentLevel(typing.NamedTuple):
        ifdef: bool # whether this is both `{` and `#if`, or just `{`
    indent_stack: list[IndentLevel]
    # Number of open levels that also opened an `#if`.
    def ifdef_lvl() -> int:
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)
    # Number of open `{` levels, i.e. the current tab depth.
    def indent_lvl() -> int:
        return len(indent_stack)
    # Largest encoded size of the type currently being generated, over all
    # of its versions.  It is set in the loop below before idlutil.walk().
    max_size: int
    # idlutil.walk() callback: emit the marshal code for `path`.  It returns
    # `pop`, which closes the scopes opened here.  Presumably walk() calls
    # `pop` once it is done with `path` and its children.
    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        nonlocal ret
        nonlocal indent_stack
        nonlocal max_size
        indent_stack_len = len(indent_stack)
        def pop() -> None:
            nonlocal ret
            nonlocal indent_stack
            nonlocal indent_stack_len
            while len(indent_stack) > indent_stack_len:
                # The outermost scope opened for this path, if it is a
                # version-conditional (`#if` + `if`) one, is left open here.
                # NOTE(review): following members are then emitted inside it
                # until an enclosing pop() or the cleanup after
                # idlutil.walk() closes it.  Confirm this only happens when
                # those members share the same condition.
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                ret += f"{'\t'*(indent_lvl()-1)}}}\n"
                if indent_stack.pop().ifdef:
                    ret += cutil.ifdef_pop(ifdef_lvl())
        # Number of counted members on the path.  This equals the number of
        # generated `for` loops enclosing this member.
        loopdepth = sum(1 for elem in path.elems if elem.cnt)
        # The type being entered: the member's type, or the root type.
        struct = path.elems[-1].typ if path.elems else path.root
        if isinstance(struct, idl.Struct):
            # Some members' `val` expressions refer to `end` or `&member`.
            # Compute each such offset once, into a uint32_t local named
            # offsetof[_<member path>]_<name>.
            offsets: list[str] = []
            for member in struct.members:
                if not member.val:
                    continue
                for tok in member.val.tokens:
                    if not isinstance(tok, idl.ExprSym):
                        continue
                    if tok.symname == "end" or tok.symname.startswith("&"):
                        if tok.symname not in offsets:
                            offsets.append(tok.symname)
            for name in offsets:
                name_prefix = f"offsetof{''.join('_'+m.membname for m in path.elems)}_"
                if name == "end":
                    if not path.elems:
                        # At the top level, `end` is the total size from
                        # pass 1.  Presumably it fits in uint32_t because it
                        # already passed the max_msg_size check.
                        if max_size > cutil.UINT32_MAX:
                            ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
                        else:
                            ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = needed_size;\n"
                        continue
                    recurse: OffsetExprRecursion = go_to_end
                else:
                    # `&member`: the offset of that member within `struct`.
                    assert name.startswith("&")
                    name = name[1:]
                    recurse = go_to_tok(name)
                expr = get_offset_expr(struct, recurse)
                # Render relative to this struct's value: "val->" at the
                # root, otherwise the member path followed by ".".
                expr_prefix = path.c_str("val->", loopdepth)
                if not expr_prefix.endswith(">"):
                    expr_prefix += "."
                ret += expr.gen_c(
                    "uint32_t",
                    name_prefix + name,
                    expr_prefix,
                    indent_lvl(),
                    loopdepth,
                )
        if not path.elems:
            return idlutil.WalkCmd.KEEP_GOING, pop
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        # A member in only a subset of its parent's versions gets an `#if`
        # plus a runtime version check.
        # NOTE(review): the runtime `if` is opened only when ifdef_push()
        # returns text.  Otherwise no new scope is opened for this member;
        # presumably it reuses one left open by a previous sibling's pop().
        # Confirm.
        if child.in_versions < parent.in_versions:
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                ret += line
                ret += (
                    f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                )
                indent_stack.append(IndentLevel(ifdef=True))
        if child.cnt:
            cnt_path = path.parent().add(child.cnt)
            # Byte-sized elements are emitted as one block rather than one
            # element at a time.  `[:-3]` strips what is presumably a
            # trailing element subscript such as "[i]" from the C path.
            if child.typ.static_size == 1: # SPECIAL (zerocopy)
                # stat's bytes are copied into net_copied.  Everything else
                # is referenced in place.
                if path.root.typname == "stat": # SPECIAL (stat)
                    ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                else:
                    ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                return idlutil.WalkCmd.KEEP_GOING, pop
            # Otherwise loop over the elements.  The loop variable is picked
            # by nesting depth ("i", "j", ...), as in OffsetExpr.gen_c().
            loopvar = chr(ord("i") + loopdepth - 1)
            ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
        if not isinstance(child.typ, idl.Struct):
            # Scalar leaf: write either its computed `val` expression or the
            # field's value.
            if child.val:
                # Resolve `end` / `&member` symbols to the offsetof_* locals
                # emitted above when the parent struct was entered.
                def lookup_sym(sym: str) -> str:
                    nonlocal path
                    if sym.startswith("&"):
                        sym = sym[1:]
                    return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"
                val = c9util.idl_expr(child.val, lookup_sym)
            else:
                val = path.c_str("val->")
            # Bitfields are masked with the `<typname>_masks` entry for
            # ctx->ctx->version.
            if isinstance(child.typ, idl.Bitfield):
                val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
            ret += f"{'\t'*indent_lvl()}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
        return idlutil.WalkCmd.KEEP_GOING, pop
    for typ in typs:
        # Only messages, plus stat, get a marshal function.
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ): # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n"
        # Pass 1 - check size
        # Types whose maximum size exceeds UINT32_MAX need a 64-bit counter.
        max_size = max(typ.max_size(v) for v in typ.in_versions)
        if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e)
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint64_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
        else:
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint32_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
        # Only messages report an error; stat just returns true.
        if isinstance(typ, idl.Message): # SPECIAL (disable for stat)
            ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
            ret += f'\t\t\t"{typ.typname}",\n'
            ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
            ret += "\t\t\tctx->ctx->max_msg_size);\n"
        ret += "\t\treturn true;\n"
        ret += "\t}\n"
        # Pass 2 - write data
        # The function body's `{` is the root level.  It counts as an ifdef
        # level, matching the `#if` pushed at level 1 above.
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        # Close any scopes that pop() left open.
        while len(indent_stack) > 1:
            ret += f"{'\t'*(indent_lvl()-1)}}}\n"
            if indent_stack.pop().ifdef:
                ret += cutil.ifdef_pop(ifdef_lvl())
        # Return
        ret += "\treturn false;\n"
        ret += "}\n"
        ret += cutil.ifdef_pop(0)
    return ret
|