diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-15 14:04:10 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-22 19:18:38 -0600 |
commit | d912a4d79ed9e51e5dfcc24e6445c1de7dbb1a30 (patch) | |
tree | 4d9604cdfa7269cf827b87301e9a685a55c6da2e | |
parent | 185c3329145959433b8b805de5f114b66b8fcaee (diff) |
lib9p: idl.gen: Have a separate type that excludes idl.Primitive
-rwxr-xr-x | lib9p/idl.gen | 14 | ||||
-rw-r--r-- | lib9p/idl/__init__.py | 17 |
2 files changed, 16 insertions, 15 deletions
diff --git a/lib9p/idl.gen b/lib9p/idl.gen index 779b6d5..b75ffd6 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -147,8 +147,8 @@ def ifdef_pop(n: int) -> str: # topo_sorted() ################################################################ -def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]: - ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter() +def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]: + ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter() for typ in typs: match typ: case idl.Number(): @@ -375,7 +375,7 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: # Generate .h ################################################################## -def gen_h(versions: set[str], typs: list[idl.Type]) -> str: +def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: global _ifdef_stack _ifdef_stack = [] @@ -451,7 +451,7 @@ enum {idprefix}version {{ """ def per_version_comment( - typ: idl.Type, fn: typing.Callable[[idl.Type, str], str] + typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] ) -> str: lines: dict[str, str] = {} for version in sorted(typ.in_versions): @@ -471,7 +471,7 @@ enum {idprefix}version {{ ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - def sum_size(typ: idl.Type, version: str) -> str: + def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) assert ( sz.min_size <= sz.exp_size @@ -643,7 +643,7 @@ enum {idprefix}version {{ # Generate .c ################################################################## -def gen_c(versions: set[str], typs: list[idl.Type]) -> str: +def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: global _ifdef_stack _ifdef_stack = [] @@ -1088,7 +1088,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] - def get_offset_expr(typ: idl.Type, recurse: OffsetExprRecursion) -> OffsetExpr: + def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: if not isinstance(typ, idl.Struct): assert typ.static_size ret = OffsetExpr() diff --git a/lib9p/idl/__init__.py b/lib9p/idl/__init__.py index 042a438..04e1791 100644 --- a/lib9p/idl/__init__.py +++ b/lib9p/idl/__init__.py @@ -271,6 +271,7 @@ class Message(Struct): type Type = Primitive | Number | Bitfield | Struct | Message +type UserType = Number | Bitfield | Struct | Message T = typing.TypeVar("T", Number, Bitfield, Struct, Message) # Parse ######################################################################## @@ -417,8 +418,8 @@ re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg def parse_file( - filename: str, get_include: typing.Callable[[str], tuple[str, list[Type]]] -) -> tuple[str, list[Type]]: + filename: str, get_include: typing.Callable[[str], tuple[str, list[UserType]]] +) -> tuple[str, list[UserType]]: version: str | None = None env: dict[str, Type] = { "1": Primitive.u8, @@ -577,7 +578,7 @@ def parse_file( if not version: raise SyntaxError("must have exactly 1 version line") - typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] + typs: list[UserType] = [x for x in env.values() if not isinstance(x, Primitive)] for typ in [typ for typ in typs if isinstance(typ, Struct)]: valid_syms = [ @@ -607,21 +608,21 @@ def parse_file( class Parser: - cache: dict[str, tuple[str, list[Type]]] = {} + cache: dict[str, tuple[str, list[UserType]]] = {} - def parse_file(self, filename: str) -> tuple[str, list[Type]]: + def parse_file(self, filename: str) -> tuple[str, list[UserType]]: filename = os.path.normpath(filename) if filename not in self.cache: - def get_include(other_filename: str) -> tuple[str, list[Type]]: + def get_include(other_filename: str) -> tuple[str, list[UserType]]: return self.parse_file(os.path.join(filename, "..", other_filename)) self.cache[filename] = parse_file(filename, get_include) return self.cache[filename] - def all(self) -> tuple[set[str], list[Type]]: + def all(self) -> tuple[set[str], list[UserType]]: ret_versions: set[str] = set() - ret_typs: dict[str, Type] = {} + ret_typs: dict[str, UserType] = {} for version, typs in self.cache.values(): if version in ret_versions: raise ValueError(f"duplicate protocol version {repr(version)}") |