1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
|
# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later
import idl
from . import c9util, cutil, idlutil
# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".
# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]
def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool:
    """Report whether the generated validator must keep *member*'s decoded value.

    The value is needed when *member* carries a ``max=`` or ``val=``
    constraint (checked after all members are validated), or when some
    member of *typ* uses *member* as its ``cnt`` (the list length).
    """
    if member.max or member.val:
        return True
    return any(other.cnt == member for other in typ.members)
def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
ret = """
/* validate_* *****************************************************************/
LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) {
\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset))
\t\t/* If needed-net-size overflowed uint32_t, then
\t\t * there's no way that actual-net-size will live up to
\t\t * that. */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\tif (ctx->net_offset > ctx->net_size)
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}
LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) {
\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra))
\t\t/* If needed-host-size overflowed size_t, then there's
\t\t * no way that actual-net-size will live up to
\t\t * that. */
\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content");
\treturn false;
}
LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx,
size_t cnt,
_validate_fn_t item_fn, size_t item_host_size) {
\tfor (size_t i = 0; i < cnt; i++)
\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx))
\t\t\treturn true;
\treturn false;
}
LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); }
LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); }
LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); }
LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); }
"""
for typ in idlutil.topo_sorted(typs):
inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE"
argfn = (
c9util.arg_unused
if (isinstance(typ, idl.Struct) and not typ.members)
else c9util.arg_used
)
ret += "\n"
ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n"
match typ:
case idl.Number():
ret += f"\treturn validate_{typ.prim.typname}(ctx);\n"
case idl.Bitfield():
ret += f"\t if (validate_{typ.static_size}(ctx))\n"
ret += "\t\treturn true;\n"
ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n"
if typ.static_size == 1:
ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n"
else:
ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n"
ret += "\tif (val & ~mask)\n"
ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n'
ret += "\treturn false;\n"
case idl.Struct(): # and idl.Message()
if len(typ.members) == 0:
ret += "\treturn false;\n"
ret += "}\n"
continue
# Pass 1 - declare value variables
for member in typ.members:
if should_save_value(typ, member):
ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
ret += f"\t{c9util.typename(member.typ)} {member.membname};\n"
ret += cutil.ifdef_pop(1)
# Pass 2 - declare offset variables
mark_offset: set[str] = set()
for member in typ.members:
for tok in [*member.max.tokens, *member.val.tokens]:
if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"):
if tok.symname[1:] not in mark_offset:
ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n"
mark_offset.add(tok.symname[1:])
# Pass 3 - main pass
ret += "\treturn false\n"
for member in typ.members:
ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
ret += "\t || "
if member.in_versions != typ.in_versions:
ret += "( " + c9util.ver_cond(member.in_versions) + " && "
if member.cnt is not None:
if member.typ.static_size == 1: # SPECIAL (zerocopy)
ret += f"_validate_size_net(ctx, {member.cnt.membname})"
else:
ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))"
if typ.typname == "s": # SPECIAL (string)
ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })'
else:
if should_save_value(typ, member):
ret += "("
if member.membname in mark_offset:
ret += f"({{ _{member.membname}_offset = ctx->net_offset; "
ret += f"validate_{member.typ.typname}(ctx)"
if member.membname in mark_offset:
ret += "; })"
if should_save_value(typ, member):
nbytes = member.static_size
assert nbytes
if nbytes == 1:
ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))"
else:
ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))"
if member.in_versions != typ.in_versions:
ret += " )"
ret += "\n"
# Pass 4 - validate ,max= and ,val= constraints
for member in typ.members:
def lookup_sym(sym: str) -> str:
match sym:
case "end":
return "ctx->net_offset"
case _:
assert sym.startswith("&")
return f"_{sym[1:]}_offset"
if member.max:
assert member.static_size
nbits = member.static_size * 8
ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n"
ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n'
if member.val:
assert member.static_size
nbits = member.static_size * 8
ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n"
ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n'
ret += cutil.ifdef_pop(1)
ret += "\t ;\n"
ret += "}\n"
ret += cutil.ifdef_pop(0)
return ret
|