summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-20 12:16:39 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2024-09-20 12:16:39 -0600
commit2caf7b05a1de8fcad42e33159008ca8f203f8caf (patch)
tree88bae960e0bf7741a7144f677915fa48af42e358
parent0cc4f3ab82473ca09373a5b1d42223c69bf92fce (diff)
wip
-rwxr-xr-xnet9p_defs.gen76
1 files changed, 59 insertions, 17 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen
index 0e75e42..8e8fe27 100755
--- a/net9p_defs.gen
+++ b/net9p_defs.gen
@@ -240,43 +240,85 @@ def gen_c_check_net_len(msg: Message) -> str:
ret += "\t\treturn -EINVAL;\n"
return ret
- ret += f"\tuint64_t net_offset = 0;\n"
- static_acc = 0
- def _gen_c_check_net_len(prefix: str, struct: Struct | Message) -> str:
+ def has_big_list(typ: Atom | Struct | List | Message) -> bool:
+ match typ:
+ case Atom.u8:
+ return False
+ case Atom.u16:
+ return False
+ case Atom.u32:
+ return False
+ case Atom.u64:
+ return False
+ case Struct() | Message():
+ return any(has_big_list(member.typ) for member in typ.members)
+ case List():
+ sz = static_net_size(typ.typ)
+ if sz is None:
+ return False
+ return sz > 1
+ case _:
+ raise ValueError(f"not a type: {typ.__class__.__name__}")
+
+ inited: bool = False
+ static_acc: int = 0
+ static_acc_name: list[str] = []
+ def _gen_c_check_net_len(wsprefix: str, nameprefix: str, struct: Struct | Message) -> str:
+ nonlocal inited
nonlocal static_acc
+ nonlocal static_acc_name
ret: str = ""
-
+
+ if has_big_list(struct):
+ ret += f"{wsprefix}uint32_t sizeof_array;\n"
prev_size: int = 0
for member in struct.members:
if (sz := static_net_size(member.typ)) is not None:
+ static_acc_name += [nameprefix+member.name]
static_acc += sz
prev_size = sz
elif isinstance(member.typ, Struct):
- ret += _gen_c_check_net_len(prefix, member.typ)
+ ret += _gen_c_check_net_len(wsprefix, nameprefix+member.name+'.', member.typ)
elif isinstance(member.typ, List):
- if static_acc:
- ret += f"{prefix}net_offset += {static_acc};\n"
+ if not inited:
+ ret += f"{wsprefix}uint32_t calced_net_len = {static_acc};"
+ if static_acc:
+ ret += f" /* {', '.join(static_acc_name)} */"
+ ret += "\n"
+ static_acc_name = []
+ static_acc = 0
+ inited = True
+ elif static_acc:
+ ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n"
+ static_acc_name = []
static_acc = 0
- ret += f"{prefix}if ((uint64_t)net_len < net_offset)\n"
- ret += f"{prefix}\treturn -EINVAL;\n"
+ ret += f"{wsprefix}/* {nameprefix}{member.name} */\n"
+ ret += f"{wsprefix}if (net_len < calced_net_len) return -EWRONGLEN;\n"
if (sz := static_net_size(member.typ.typ)) is not None:
- ret += f"{prefix}net_offset += unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}])*{sz};\n"
+ if sz == 1:
+ ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), &calced_net_len))\n"
+ else:
+ ret += f"{wsprefix}if (__builtin_mul_overflow(unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), {sz}, &sizeof_array)\n"
+ ret += f"{wsprefix} || __builtin_add_overflow(calced_net_len, sizeof_array, &calced_net_len)\n"
+ ret += f"{wsprefix}\treturn -EWRONGLEN;\n"
else:
assert isinstance(member.typ.typ, Struct)
- ret += f"{prefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}]); i < cnt; i++) {{\n"
- ret += _gen_c_check_net_len(prefix + "\t", member.typ.typ)
+ ret += f"{wsprefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]); i < cnt; i++) {{\n"
+ ret += _gen_c_check_net_len(wsprefix + "\t", nameprefix+member.name+'[i].', member.typ.typ)
if static_acc:
- ret += f"{prefix}net_offset += {static_acc};\n"
+ ret += f"{wsprefix}\tif (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n"
+ static_acc_name = []
static_acc = 0
- ret += f"{prefix}\n"
+ ret += f"{wsprefix}\n"
else:
raise ValueError(f"not a type: {member.typ.__class__.__name__}")
return ret
- ret += _gen_c_check_net_len("\t", msg)
+ ret += _gen_c_check_net_len('\t', '.', msg)
if static_acc:
- ret += f"\tnet_offset += {static_acc};\n"
- ret += "\tif ((uint64_t)net_len != net_offset)\n"
+ ret += f"\tcalced_net_len += {static_acc}; /* {', '.join(static_acc_name)} */\n"
+ ret += "\t/* final check */\n"
+ ret += "\tif (net_len != calced_net_len)\n"
ret += "\t\treturn -EINVAL;\n"
return ret