diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-20 12:16:39 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2024-09-20 12:16:39 -0600 |
commit | 2caf7b05a1de8fcad42e33159008ca8f203f8caf (patch) | |
tree | 88bae960e0bf7741a7144f677915fa48af42e358 | |
parent | 0cc4f3ab82473ca09373a5b1d42223c69bf92fce (diff) |
wip
-rwxr-xr-x | net9p_defs.gen | 76 |
1 files changed, 59 insertions, 17 deletions
diff --git a/net9p_defs.gen b/net9p_defs.gen index 0e75e42..8e8fe27 100755 --- a/net9p_defs.gen +++ b/net9p_defs.gen @@ -240,43 +240,85 @@ def gen_c_check_net_len(msg: Message) -> str: ret += "\t\treturn -EINVAL;\n" return ret - ret += f"\tuint64_t net_offset = 0;\n" - static_acc = 0 - def _gen_c_check_net_len(prefix: str, struct: Struct | Message) -> str: + def has_big_list(typ: Atom | Struct | List | Message) -> bool: + match typ: + case Atom.u8: + return False + case Atom.u16: + return False + case Atom.u32: + return False + case Atom.u64: + return False + case Struct() | Message(): + return any(has_big_list(member.typ) for member in typ.members) + case List(): + sz = static_net_size(typ.typ) + if sz is None: + return False + return sz > 1 + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + inited: bool = False + static_acc: int = 0 + static_acc_name: list[str] = [] + def _gen_c_check_net_len(wsprefix: str, nameprefix: str, struct: Struct | Message) -> str: + nonlocal inited nonlocal static_acc + nonlocal static_acc_name ret: str = "" - + + if has_big_list(struct): + ret += f"{wsprefix}uint32_t sizeof_array;\n" prev_size: int = 0 for member in struct.members: if (sz := static_net_size(member.typ)) is not None: + static_acc_name += [nameprefix+member.name] static_acc += sz prev_size = sz elif isinstance(member.typ, Struct): - ret += _gen_c_check_net_len(prefix, member.typ) + ret += _gen_c_check_net_len(wsprefix, nameprefix+member.name+'.', member.typ) elif isinstance(member.typ, List): - if static_acc: - ret += f"{prefix}net_offset += {static_acc};\n" + if not inited: + ret += f"{wsprefix}uint32_t calced_net_len = {static_acc};" + if static_acc: + ret += f" /* {', '.join(static_acc_name)} */" + ret += "\n" + static_acc_name = [] + static_acc = 0 + inited = True + elif static_acc: + ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n" + static_acc_name = [] static_acc = 0 - ret += f"{prefix}if ((uint64_t)net_len < net_offset)\n" - ret += f"{prefix}\treturn -EINVAL;\n" + ret += f"{wsprefix}/* {nameprefix}{member.name} */\n" + ret += f"{wsprefix}if (net_len < calced_net_len) return -EWRONGLEN;\n" if (sz := static_net_size(member.typ.typ)) is not None: - ret += f"{prefix}net_offset += unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}])*{sz};\n" + if sz == 1: + ret += f"{wsprefix}if (__builtin_add_overflow(calced_net_len, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), &calced_net_len))\n" + else: + ret += f"{wsprefix}if (__builtin_mul_overflow(unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]), {sz}, &sizeof_array)\n" + ret += f"{wsprefix} || __builtin_add_overflow(calced_net_len, sizeof_array, &calced_net_len)\n" + ret += f"{wsprefix}\treturn -EWRONGLEN;\n" else: assert isinstance(member.typ.typ, Struct) - ret += f"{prefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[net_offset-{prev_size}]); i < cnt; i++) {{\n" - ret += _gen_c_check_net_len(prefix + "\t", member.typ.typ) + ret += f"{wsprefix}for (uint{prev_size*8}_t i, cnt = 0, unmarshal_u{prev_size*8}le(&net_bytes[calced_net_len-{prev_size}]); i < cnt; i++) {{\n" + ret += _gen_c_check_net_len(wsprefix + "\t", nameprefix+member.name+'[i].', member.typ.typ) if static_acc: - ret += f"{prefix}net_offset += {static_acc};\n" + ret += f"{wsprefix}\tif (__builtin_add_overflow(calced_net_len, {static_acc}, &calced_net_len)) return -EWRONGLEN; /* {', '.join(static_acc_name)} */\n" + static_acc_name = [] static_acc = 0 - ret += f"{prefix}\n" + ret += f"{wsprefix}\n" else: raise ValueError(f"not a type: {member.typ.__class__.__name__}") return ret - ret += _gen_c_check_net_len("\t", msg) + ret += _gen_c_check_net_len('\t', '.', msg) if static_acc: - ret += f"\tnet_offset += {static_acc};\n" - ret += "\tif ((uint64_t)net_len != net_offset)\n" + ret += f"\tcalced_net_len += {static_acc}; /* {', '.join(static_acc_name)} */\n" + ret += "\t/* final check */\n" + ret += "\tif (net_len != calced_net_len)\n" ret += "\t\treturn -EINVAL;\n" return ret |