diff options
-rw-r--r-- | .editorconfig | 2 | ||||
-rw-r--r-- | .pylintrc | 40 | ||||
-rw-r--r-- | GNUmakefile | 1 | ||||
-rw-r--r-- | build-aux/requirements.txt | 3 | ||||
-rwxr-xr-x | build-aux/stack.c.gen | 92 | ||||
-rw-r--r-- | gdb-helpers/libcr.py | 59 | ||||
-rw-r--r-- | gdb-helpers/rp2040.py | 10 | ||||
-rwxr-xr-x | lib9p/idl.gen | 270 | ||||
-rw-r--r-- | lib9p/idl/__init__.py | 148 | ||||
-rwxr-xr-x | lib9p/include/lib9p/linux-errno.h.gen | 12 |
10 files changed, 358 insertions, 279 deletions
diff --git a/.editorconfig b/.editorconfig index 969313a..69fefd5 100644 --- a/.editorconfig +++ b/.editorconfig @@ -52,7 +52,7 @@ _mode = pip [**/Documentation/**.txt] _mode = man-cat -[{.editorconfig,.gitmodules}] +[{.editorconfig,.gitmodules,.pylintrc}] _mode = ini [.gitignore] diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..944973f --- /dev/null +++ b/.pylintrc @@ -0,0 +1,40 @@ +# .pylintrc - Configuration for Pylint +# +# Copyright (C) 2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +[MAIN] +analyse-fallback-blocks=yes +enable-all-extensions=yes +fail-on=all +#init-hook='sys.path.insert(0, os.path.normpath(os.path.join(__file__, "..")))' + +[MESSAGES CONTROL] + +disable=missing-module-docstring, + missing-class-docstring, + missing-function-docstring, + fixme, + line-too-long, + unused-argument, + too-few-public-methods, + invalid-name, + too-many-lines, + too-many-locals, + too-many-statements, + too-many-nested-blocks, + too-many-branches, + too-many-instance-attributes, + too-many-arguments, + too-many-positional-arguments, + too-many-return-statements, + global-statement, + import-outside-toplevel + +[REPORTS] +reports=no +score=no + +[VARIABLES] +allow-global-unused-variables=no +init-import=yes diff --git a/GNUmakefile b/GNUmakefile index ee4026f..339ae4b 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -122,6 +122,7 @@ lint/python3: lint/%: build-aux/venv ./build-aux/venv/bin/mypy --strict --scripts-are-modules $(sources_$*) ./build-aux/venv/bin/black --check $(sources_$*) ./build-aux/venv/bin/isort --check $(sources_$*) + ./build-aux/venv/bin/pylint $(sources_$*) ! grep -nh 'SPECIAL$$' -- lib9p/idl.gen lint/c: lint/%: build-aux/lint-h build-aux/get-dscname ./build-aux/lint-h $(filter %.h,$(sources_$*)) diff --git a/build-aux/requirements.txt b/build-aux/requirements.txt index 43a13be..fb76559 100644 --- a/build-aux/requirements.txt +++ b/build-aux/requirements.txt @@ -1,9 +1,10 @@ # build-aux/requirements.txt - List of Python dev requirements # -# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> # SPDX-License-Identifier: AGPL-3.0-or-later mypy types-gdb>=15.0.0.20241204 # https://github.com/python/typeshed/pull/13169 black isort +pylint diff --git a/build-aux/stack.c.gen b/build-aux/stack.c.gen index 1a96379..5a983cb 100755 --- a/build-aux/stack.c.gen +++ b/build-aux/stack.c.gen @@ -89,6 +89,10 @@ def parse_vcg(reader: typing.TextIO) -> typing.Iterator[VCGElem]: elem.attrs[k] = v pos = m.end() + del _raise + del pos + del line + del lineno yield elem @@ -162,7 +166,7 @@ class Node: def synthetic_node( - name: str, nstatic: int, calls: typing.Collection[str] = set() + name: str, nstatic: int, calls: typing.Collection[str] = frozenset() ) -> Node: n = Node() @@ -200,7 +204,7 @@ class Application(typing.Protocol): def indirect_callees( self, elem: VCGElem ) -> tuple[typing.Collection[QName], bool]: ... - def skip_call(self, chain: list[QName], funcname: QName) -> bool: ... + def skip_call(self, chain: typing.Sequence[QName], funcname: QName) -> bool: ... def analyze( @@ -219,8 +223,8 @@ def analyze( flags=re.MULTILINE, ) - graph: dict[QName, Node] = dict() - qualified: dict[BaseName, set[QName]] = dict() + graph: dict[QName, Node] = {} + qualified: dict[BaseName, set[QName]] = {} def handle_elem(elem: VCGElem) -> None: match elem.typ: @@ -288,7 +292,7 @@ def analyze( raise ValueError(f"unknown elem type {repr(elem.typ)}") for ci_fname in ci_fnames: - with open(ci_fname, "r") as fh: + with open(ci_fname, "r", encoding="utf-8") as fh: for elem in parse_vcg(fh): handle_elem(elem) @@ -329,7 +333,9 @@ def analyze( track_inclusion: bool = True def nstatic( - orig_funcname: QName, chain: list[QName] = [], missing_ok: bool = False + orig_funcname: QName, + chain: typing.Sequence[QName] = (), + missing_ok: bool = False, ) -> int: nonlocal dbg nonlocal track_inclusion @@ -350,7 +356,7 @@ def analyze( return 0 if len(chain) == cfg_max_call_depth: - raise ValueError(f"max call depth exceeded: {chain+[funcname]}") + raise ValueError(f"max call depth exceeded: {[*chain, funcname]}") node = graph[funcname] if dbg: @@ -363,13 +369,13 @@ def analyze( [ 0, *[ - nstatic(call, chain + [funcname], missing_ok) + nstatic(call, [*chain, funcname], missing_ok) for call, missing_ok in node.calls.items() ], ] ) - groups: dict[str, AnalyzeResultGroup] = dict() + groups: dict[str, AnalyzeResultGroup] = {} for grp_name, grp_filter in app_func_filters.items(): rows: dict[QName, AnalyzeResultVal] = {} for funcname in graph: @@ -395,7 +401,7 @@ def read_source(location: str) -> str: filename = m.group("filename") row = int(m.group("row")) - 1 col = int(m.group("col")) - 1 - with open(m.group("filename"), "r") as fh: + with open(filename, "r", encoding="utf-8") as fh: return fh.readlines()[row][col:].rstrip() @@ -430,7 +436,7 @@ class Plugin(typing.Protocol): def indirect_callees( self, loc: str, line: str ) -> tuple[typing.Collection[QName], bool] | None: ... - def skip_call(self, chain: list[QName], call: QName) -> bool: ... + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: ... class PluginApplication: @@ -464,7 +470,7 @@ class PluginApplication: placeholder += " at " + self._location_xform(elem.attrs.get("label", "")) return [QName(placeholder)], False - def skip_call(self, chain: list[QName], funcname: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], funcname: QName) -> bool: for plugin in self._plugins: if plugin.skip_call(chain, funcname): return True @@ -499,7 +505,7 @@ class CmdPlugin: return [QName("get_root")], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -513,7 +519,7 @@ class LibObjPlugin: re_lo_iface = re.compile(r"^\s*#\s*define\s+(?P<name>\S+)_LO_IFACE") re_lo_func = re.compile(r"LO_FUNC *\([^,]*, *(?P<name>[^,) ]+) *[,)]") for fname in arg_c_fnames: - with open(fname, "r") as fh: + with open(fname, "r", encoding="utf-8") as fh: while line := fh.readline(): if m := re_lo_iface.match(line): iface_name = m.group("name") @@ -534,15 +540,15 @@ class LibObjPlugin: r"^LO_IMPLEMENTATION_[HC]\s*\(\s*(?P<iface>[^, ]+)\s*,\s*(?P<impl_typ>[^,]+)\s*,\s*(?P<impl_name>[^, ]+)\s*[,)].*" ) for fname in arg_c_fnames: - with open(fname, "r") as fh: + with open(fname, "r", encoding="utf-8") as fh: for line in fh: line = line.strip() if m := re_lo_implementation.match(line): implementations[m.group("iface")].add(m.group("impl_name")) objcalls: dict[str, set[QName]] = {} # method_name => {method_impls} - for iface_name in ifaces: - for method_name in ifaces[iface_name]: + for iface_name, iface in ifaces.items(): + for method_name in iface: if method_name not in objcalls: objcalls[method_name] = set() for impl_name in implementations[iface_name]: @@ -576,7 +582,7 @@ class LibObjPlugin: ], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -645,7 +651,7 @@ class LibHWPlugin: ], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -667,7 +673,7 @@ class LibCRPlugin: ) -> tuple[typing.Collection[QName], bool] | None: return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -696,7 +702,7 @@ class LibCRIPCPlugin: ], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -732,7 +738,7 @@ class Lib9PPlugin: def config_h_get(varname: str) -> int | None: if config_h_fname: - with open(config_h_fname, "r") as fh: + with open(config_h_fname, "r", encoding="utf-8") as fh: for line in fh: line = line.rstrip() if line.startswith("#define"): @@ -753,7 +759,7 @@ class Lib9PPlugin: r"^\s*\[LIB9P_TYP_T[^]]+\]\s*=\s*\(tmessage_handler\)\s*(?P<handler>\S+),\s*$" ) tmessage_handlers = set() - with open(lib9p_srv_c_fname, "r") as fh: + with open(lib9p_srv_c_fname, "r", encoding="utf-8") as fh: for line in fh: line = line.rstrip() if m := re_tmessage_handler.fullmatch(line): @@ -763,7 +769,7 @@ class Lib9PPlugin: lib9p_msgs: set[str] = set() if lib9p_generated_c_fname: re_lib9p_msg_entry = re.compile(r"^\s*_MSG_(?:[A-Z]+)\((?P<typ>\S+)\),$") - with open(lib9p_generated_c_fname, "r") as fh: + with open(lib9p_generated_c_fname, "r", encoding="utf-8") as fh: for line in fh: line = line.rstrip() if m := re_lib9p_msg_entry.fullmatch(line): @@ -776,7 +782,7 @@ class Lib9PPlugin: assert self.CONFIG_9P_SRV_MAX_REQS if "read" in str(name.base()): return self._CONFIG_9P_NUM_SOCKS - elif "write" in str(name.base()): + if "write" in str(name.base()): return self._CONFIG_9P_NUM_SOCKS * self.CONFIG_9P_SRV_MAX_REQS return 1 @@ -811,7 +817,7 @@ class Lib9PPlugin: return [QName(f"{meth}_{msg}") for msg in self.lib9p_msgs], True return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: if "lib9p/srv.c:srv_util_pathfree" in str(call): assert isinstance(self.CONFIG_9P_SRV_MAX_DEPTH, int) if len(chain) >= self.CONFIG_9P_SRV_MAX_DEPTH and all( @@ -850,7 +856,7 @@ class LibMiscPlugin: ) -> tuple[typing.Collection[QName], bool] | None: return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: if ( len(chain) > 1 and str(chain[-1].base()) == "__assert_msg_fail" @@ -904,7 +910,7 @@ class PicoFmtPlugin: return [QName(x) for x in self.known_fct.values()], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: if str(call.base()) in self.known_out.values(): out = "" for pcall in chain: @@ -1030,7 +1036,7 @@ class PicoSDKPlugin: return self.app_preinit_array, False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False def extra_nodes(self) -> typing.Collection[Node]: @@ -1184,7 +1190,7 @@ class TinyUSBDevicePlugin: r"^\s*#\s*define\s+(?P<k>CFG_TUD_(?:\S{3}|AUDIO|VIDEO|MIDI|VENDOR|USBTMC|DFU_RUNTIME|ECM_RNDIS))\s+(?P<v>\S+).*" ) tusb_config: dict[str, bool] = {} - with open(tusb_config_h_fname, "r") as fh: + with open(tusb_config_h_fname, "r", encoding="utf-8") as fh: in_table = False for line in fh: line = line.rstrip() @@ -1200,7 +1206,7 @@ class TinyUSBDevicePlugin: re_tud_if1 = re.compile(r"^\s*#\s*if (\S+)\s*") re_tud_if2 = re.compile(r"^\s*#\s*if (\S+)\s*\|\|\s*(\S+)\s*") re_tud_endif = re.compile(r"^\s*#\s*endif\s*") - with open(usbd_c_fname, "r") as fh: + with open(usbd_c_fname, "r", encoding="utf-8") as fh: in_table = False enabled = True for line in fh: @@ -1253,15 +1259,15 @@ class TinyUSBDevicePlugin: QName("tud_vendor_control_xfer_cb"), *sorted(self.tud_drivers["control_xfer_cb"]), ], False - elif call.startswith("driver->"): + if call.startswith("driver->"): return sorted(self.tud_drivers[call[len("driver->") :]]), False - elif call == "event.func_call.func": + if call == "event.func_call.func": # callback from usb_defer_func() return [], False return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -1328,7 +1334,7 @@ class NewlibPlugin: ) -> tuple[typing.Collection[QName], bool] | None: return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -1358,7 +1364,7 @@ class LibGCCPlugin: ) -> tuple[typing.Collection[QName], bool] | None: return None - def skip_call(self, chain: list[QName], call: QName) -> bool: + def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: return False @@ -1516,20 +1522,20 @@ def main( size: int rows: list[CrRow] = [] - main: CrRow | None = None + mainrow: CrRow | None = None for funcname, val in result.groups["Threads"].rows.items(): name = str(funcname.base()) base = val.nstatic size = base + intrstack if name in ("main", "_entry_point"): - main = CrRow(name=name, cnt=1, base=base, size=size) + mainrow = CrRow(name=name, cnt=1, base=base, size=size) else: size = next_power_of_2(size + stack_guard_size) - stack_guard_size rows.append(CrRow(name=name, cnt=val.cnt, base=base, size=size)) namelen = max(len(r.name) for r in rows) baselen = max(len(str(r.base)) for r in rows) sizesum = sum(r.cnt * (r.size + stack_guard_size) for r in rows) - sizelen = len(str(max(sizesum, main.size if main else 0))) + sizelen = len(str(max(sizesum, mainrow.size if mainrow else 0))) def print_row(comment: bool, name: str, size: int, eqn: str | None = None) -> None: prefix = "const size_t CONFIG_COROUTINE_STACK_SIZE_" @@ -1554,12 +1560,12 @@ def main( f"LM_NEXT_POWER_OF_2({str(row.base).rjust(baselen)}+{intrstack}+{stack_guard_size})-{stack_guard_size}", ) print_row(True, "TOTAL (inc. stack guard)", sizesum) - if main: + if mainrow: print_row( True, "MAIN/KERNEL", - main.size, - f" {str(main.base).rjust(baselen)}+{intrstack}", + mainrow.size, + f" {str(mainrow.base).rjust(baselen)}+{intrstack}", ) print() print("/*") @@ -1594,7 +1600,7 @@ if __name__ == "__main__": for obj_fname in obj_fnames: if re_c_obj_suffix.search(obj_fname): ci_fnames.add(re_c_obj_suffix.sub(".c.ci", obj_fname)) - with open(obj_fname + ".d", "r") as fh: + with open(obj_fname + ".d", "r", encoding="utf-8") as fh: c_fnames.update( fh.read().replace("\\\n", " ").split(":")[-1].split() ) diff --git a/gdb-helpers/libcr.py b/gdb-helpers/libcr.py index 3ffafce..f74a702 100644 --- a/gdb-helpers/libcr.py +++ b/gdb-helpers/libcr.py @@ -7,11 +7,14 @@ import contextlib import time import typing -import gdb -import gdb.unwinder +import gdb # pylint: disable=import-error +import gdb.unwinder # pylint: disable=import-error # GDB helpers ################################################################## +# https://sourceware.org/bugzilla/show_bug.cgi?id=32428 +gdb_bug_32428 = True + class _gdb_Locus(typing.Protocol): @property @@ -125,19 +128,19 @@ class CrGlobals: return if self.coroutine_running: if not self.coroutine_running.is_selected(): - if True: # https://sourceware.org/bugzilla/show_bug.cgi?id=32428 + if gdb_bug_32428: print("Must return to running coroutine before continuing.") print("Hit ^C twice then run:") print(f" cr select {self.coroutine_running.id}") while True: time.sleep(1) - assert self.coroutine_running._cont_env - gdb_longjmp(self.coroutine_running._cont_env) + assert self.coroutine_running.cont_env + gdb_longjmp(self.coroutine_running.cont_env) for cr in self.coroutines: - cr._cont_env = None + cr.cont_env = None def is_valid_cid(self, cid: int) -> bool: - return 0 < cid and cid <= len(self.coroutines) + return 0 < cid <= len(self.coroutines) @property def coroutine_running(self) -> "CrCoroutine | None": @@ -211,6 +214,8 @@ class CrBreakpoint(gdb.Breakpoint): @enabled.setter def enabled(self, value: bool) -> None: self._unwinder.enabled = value + # Use a dunder-call to avoid an infinite loop. + # pylint: disable=unnecessary-dunder-call gdb.Breakpoint.enabled.__set__(self, value) # type: ignore def stop(self) -> bool: @@ -243,12 +248,12 @@ def cr_select_top_frame() -> None: class CrCoroutine: cr_globals: CrGlobals cid: int - _cont_env: gdb_JmpBuf | None + cont_env: gdb_JmpBuf | None def __init__(self, cr_globals: CrGlobals, cid: int) -> None: self.cr_globals = cr_globals self.cid = cid - self._cont_env = None + self.cont_env = None @property def id(self) -> int: @@ -269,18 +274,18 @@ class CrCoroutine: sp = int(gdb.parse_and_eval("$sp")) lo = int(gdb.parse_and_eval(f"coroutine_table[{self.id-1}].stack")) hi = lo + int(gdb.parse_and_eval(f"coroutine_table[{self.id-1}].stack_size")) - return lo <= sp and sp < hi + return lo <= sp < hi def select(self, level: int = -1) -> None: if self.cr_globals.coroutine_selected: - self.cr_globals.coroutine_selected._cont_env = gdb_setjmp() + self.cr_globals.coroutine_selected.cont_env = gdb_setjmp() - if self._cont_env: - gdb_longjmp(self._cont_env) + if self.cont_env: + gdb_longjmp(self.cont_env) else: env: gdb_JmpBuf if self == self.cr_globals.coroutine_running: - assert False # self._cont_env should have been set + assert False # self.cont_env should have been set elif self.state == self.cr_globals.CR_RUNNING: env = self.cr_globals.readjmp("&coroutine_add_env") else: @@ -332,7 +337,7 @@ class CrListCommand(gdb.Command): def invoke(self, arg: str, from_tty: bool) -> None: argv = gdb.string_to_argv(arg) if len(argv) != 0: - raise gdb.GdbError(f"Usage: cr list") + raise gdb.GdbError("Usage: cr list") rows: list[tuple[str, str, str, str, str]] = [ ("", "Id", "Name", "State", "Frame") @@ -384,7 +389,7 @@ class CrListCommand(gdb.Command): cr_select_top_frame() full = gdb.execute("frame", from_tty=from_tty, to_string=True) gdb.execute(f"select-frame level {saved_level}") - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught full = "#0 err: " + str(e) line = full.split("\n", maxsplit=1)[0] return line.split(maxsplit=1)[1] @@ -432,23 +437,23 @@ class CrSelectCommand(gdb.Command): # Wire it all in ############################################################### -cr_globals: CrGlobals | None = None +_cr_globals: CrGlobals | None = None def cr_initialize() -> None: - global cr_globals - if cr_globals: - old = cr_globals + global _cr_globals + if _cr_globals: + old = _cr_globals new = CrGlobals() for i in range(min(len(old.coroutines), len(new.coroutines))): - new.coroutines[i]._cont_env = old.coroutines[i]._cont_env + new.coroutines[i].cont_env = old.coroutines[i].cont_env old.delete() - cr_globals = new + _cr_globals = new else: - cr_globals = CrGlobals() - CrCommand(cr_globals) - CrListCommand(cr_globals) - CrSelectCommand(cr_globals) + _cr_globals = CrGlobals() + CrCommand(_cr_globals) + CrListCommand(_cr_globals) + CrSelectCommand(_cr_globals) def cr_on_new_objfile(event: gdb.Event) -> None: @@ -460,7 +465,7 @@ def cr_on_new_objfile(event: gdb.Event) -> None: gdb.events.new_objfile.disconnect(cr_on_new_objfile) -if cr_globals: +if _cr_globals: cr_initialize() else: gdb.events.new_objfile.connect(cr_on_new_objfile) diff --git a/gdb-helpers/rp2040.py b/gdb-helpers/rp2040.py index 983e13b..45bdbc7 100644 --- a/gdb-helpers/rp2040.py +++ b/gdb-helpers/rp2040.py @@ -5,7 +5,7 @@ import typing -import gdb +import gdb # pylint: disable=import-error def read_mmreg(addr: int) -> int: @@ -110,9 +110,7 @@ class RP2040ShowInterrupts(gdb.Command): """Show the RP2040's interrupt state.""" def __init__(self) -> None: - super(RP2040ShowInterrupts, self).__init__( - "rp2040-show-interrupts", gdb.COMMAND_USER - ) + super().__init__("rp2040-show-interrupts", gdb.COMMAND_USER) def invoke(self, arg: str, from_tty: bool) -> None: self.arm_cortex_m0plus_registers() @@ -219,7 +217,7 @@ class RP2040ShowDMA(gdb.Command): """Show the RP2040's DMA control registers.""" def __init__(self) -> None: - super(RP2040ShowDMA, self).__init__("rp2040-show-dma", gdb.COMMAND_USER) + super().__init__("rp2040-show-dma", gdb.COMMAND_USER) def invoke(self, arg: str, from_tty: bool) -> None: base: int = 0x50000000 @@ -273,7 +271,7 @@ class RP2040ShowDMA(gdb.Command): return "NULL " return f"0x{val:08x}" - ret = f""" + ret = """ ╓sniff_enable ║╓bswap ║║╓irq_quiet diff --git a/lib9p/idl.gen b/lib9p/idl.gen index 53b1c60..779b6d5 100755 --- a/lib9p/idl.gen +++ b/lib9p/idl.gen @@ -12,8 +12,7 @@ import sys import typing sys.path.insert(0, os.path.normpath(os.path.join(__file__, ".."))) - -import idl +import idl # pylint: disable=wrong-import-position,import-self # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -74,13 +73,13 @@ def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): - return f"{idprefix}{typ.name}_t" + return f"{idprefix}{typ.typname}_t" case idl.Bitfield(): - return f"{idprefix}{typ.name}_t" + return f"{idprefix}{typ.typname}_t" case idl.Message(): - return f"struct {idprefix}msg_{typ.name}" + return f"struct {idprefix}msg_{typ.typname}" case idl.Struct(): - return f"struct {idprefix}{typ.name}" + return f"struct {idprefix}{typ.typname}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") @@ -93,12 +92,12 @@ def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: ret.append(tok.op) case idl.ExprLit(): ret.append(str(tok.val)) - case idl.ExprSym(name="s32_max"): + case idl.ExprSym(symname="s32_max"): ret.append("INT32_MAX") - case idl.ExprSym(name="s64_max"): + case idl.ExprSym(symname="s64_max"): ret.append("INT64_MAX") case idl.ExprSym(): - ret.append(lookup_sym(tok.name)) + ret.append(lookup_sym(tok.symname)) case _: assert False return " ".join(ret) @@ -109,7 +108,6 @@ _ifdef_stack: list[str | None] = [] def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed - global _ifdef_stack while len(_ifdef_stack) < n: _ifdef_stack.append(None) @@ -191,14 +189,14 @@ class Path: for i, elem in enumerate(self.elems): if i > 0: ret += "." - ret += elem.name + ret += elem.membname if elem.cnt: ret += f"[{chr(ord('i')+loopdepth)}]" loopdepth += 1 return ret def __str__(self) -> str: - return self.c_str(self.root.name + "->") + return self.c_str(self.root.typname + "->") class WalkCmd(enum.Enum): @@ -243,7 +241,7 @@ def walk(typ: idl.Type, handle: WalkHandler) -> None: # get_buffer_size() ############################################################ -class BufferSize: +class BufferSize(typing.NamedTuple): min_size: int # really just here to sanity-check against typ.min_size(version) exp_size: int # "expected" or max-reasonable size max_size: int # really just here to sanity-check against typ.max_size(version) @@ -251,8 +249,19 @@ class BufferSize: max_copy_extra: str max_iov: int max_iov_extra: str - _starts_with_copy: bool - _ends_with_copy: bool + + +class TmpBufferSize: + min_size: int + exp_size: int + max_size: int + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + tmp_starts_with_copy: bool + tmp_ends_with_copy: bool def __init__(self) -> None: self.min_size = 0 @@ -262,14 +271,14 @@ class BufferSize: self.max_copy_extra = "" self.max_iov = 0 self.max_iov_extra = "" - self._starts_with_copy = False - self._ends_with_copy = False + self.tmp_starts_with_copy = False + self.tmp_ends_with_copy = False -def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: +def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) - ret = BufferSize() + ret = TmpBufferSize() if not isinstance(typ, idl.Struct): assert typ.static_size @@ -278,8 +287,8 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: ret.max_size = typ.static_size ret.max_copy = typ.static_size ret.max_iov = 1 - ret._starts_with_copy = True - ret._ends_with_copy = True + ret.tmp_starts_with_copy = True + ret.tmp_ends_with_copy = True return ret def handle(path: Path) -> tuple[WalkCmd, None]: @@ -292,20 +301,20 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: if child.typ.static_size == 1: # SPECIAL (zerocopy) ret.max_iov += 1 # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data - ret.exp_size += 27 if child.name == "utf8" else 8192 + ret.exp_size += 27 if child.membname == "utf8" else 8192 ret.max_size += child.max_cnt - ret._ends_with_copy = False + ret.tmp_ends_with_copy = False return WalkCmd.DONT_RECURSE, None - sub = get_buffer_size(child.typ, version) + sub = _get_buffer_size(child.typ, version) ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM ret.max_size += sub.max_size * child.max_cnt - if child.name == "wname" and path.root.name in ( + if child.membname == "wname" and path.root.typname in ( "Tsread", "Tswrite", ): # SPECIAL (9P2000.e) - assert ret._ends_with_copy - assert sub._starts_with_copy - assert not sub._ends_with_copy + assert ret.tmp_ends_with_copy + assert sub.tmp_starts_with_copy + assert not sub.tmp_ends_with_copy ret.max_copy_extra = ( f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" ) @@ -315,29 +324,29 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: ret.max_iov -= 1 else: ret.max_copy += sub.max_copy * child.max_cnt - if sub.max_iov == 1 and sub._starts_with_copy: # is purely copy + if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy ret.max_iov += 1 else: # contains zero-copy segments ret.max_iov += sub.max_iov * child.max_cnt - if ret._ends_with_copy and sub._starts_with_copy: + if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: # we can merge this one ret.max_iov -= 1 if ( - sub._ends_with_copy - and sub._starts_with_copy + sub.tmp_ends_with_copy + and sub.tmp_starts_with_copy and sub.max_iov > 1 ): # we can merge these ret.max_iov -= child.max_cnt - 1 - ret._ends_with_copy = sub._ends_with_copy + ret.tmp_ends_with_copy = sub.tmp_ends_with_copy return WalkCmd.DONT_RECURSE, None - elif not isinstance(child.typ, idl.Struct): + if not isinstance(child.typ, idl.Struct): assert child.typ.static_size - if not ret._ends_with_copy: + if not ret.tmp_ends_with_copy: if ret.max_size == 0: - ret._starts_with_copy = True + ret.tmp_starts_with_copy = True ret.max_iov += 1 - ret._ends_with_copy = True + ret.tmp_ends_with_copy = True ret.min_size += child.typ.static_size ret.exp_size += child.typ.static_size ret.max_size += child.typ.static_size @@ -350,6 +359,19 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: return ret +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + tmp = _get_buffer_size(typ, version) + return BufferSize( + min_size=tmp.min_size, + exp_size=tmp.exp_size, + max_size=tmp.max_size, + max_copy=tmp.max_copy, + max_copy_extra=tmp.max_copy_extra, + max_iov=tmp.max_iov, + max_iov_extra=tmp.max_iov_extra, + ) + + # Generate .h ################################################################## @@ -372,7 +394,7 @@ def gen_h(versions: set[str], typs: list[idl.Type]) -> str: for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg - ret += f""" + ret += """ /* config *********************************************************************/ #include "config.h" @@ -412,13 +434,15 @@ enum {idprefix}version {{ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" - namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message)) + namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" + ret += ( + f"\t{idprefix.upper()}TYP_{msg.typname.ljust(namewidth)} = {msg.msgid},\n" + ) ret += ifdef_pop(0) ret += "};\n" @@ -469,7 +493,7 @@ enum {idprefix}version {{ match typ: case idl.Number(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" - prefix = f"{idprefix.upper()}{typ.name.upper()}_" + prefix = f"{idprefix.upper()}{typ.typname.upper()}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" @@ -481,7 +505,7 @@ enum {idprefix}version {{ if aliases := [k for k in typ.names if k not in typ.bits]: names.append("") names.extend(aliases) - prefix = f"{idprefix.upper()}{typ.name.upper()}_" + prefix = f"{idprefix.upper()}{typ.typname.upper()}_" namewidth = max(len(add_prefix(prefix, name)) for name in names) ret += "\n" @@ -527,7 +551,7 @@ enum {idprefix}version {{ if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" + ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.membname};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) @@ -545,7 +569,7 @@ enum {idprefix}version {{ for typ in typs: if not isinstance(typ, idl.Message): continue - if typ.name in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) + if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) continue max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy @@ -578,7 +602,7 @@ enum {idprefix}version {{ ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) - typ = next(typ for typ in typs if typ.name == "Tswrite") + typ = next(typ for typ in typs if typ.typname == "Tswrite") sz = get_buffer_size(typ, "9P2000.e") match name: case "tmsg_max_iov": @@ -589,7 +613,7 @@ enum {idprefix}version {{ assert False ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" - ret += f"\t#else\n" + ret += "\t#else\n" indent += 1 ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) @@ -601,14 +625,14 @@ enum {idprefix}version {{ ret += "\n" ret += f"struct {idprefix}Tmsg_send_buf {{\n" - ret += f"\tsize_t iov_cnt;\n" + ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" ret += "};\n" ret += "\n" ret += f"struct {idprefix}Rmsg_send_buf {{\n" - ret += f"\tsize_t iov_cnt;\n" + ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" ret += "};\n" @@ -638,7 +662,7 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: """ # utilities ################################################################ - ret += f""" + ret += """ /* utilities ******************************************************************/ """ @@ -662,13 +686,13 @@ def gen_c(versions: set[str], typs: list[idl.Type]) -> str: xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: if ver == "unknown": # SPECIAL (initialization) - if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: + if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: if ver not in xmsg.in_versions: xmsg = None if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" + ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" ret += "\t},\n" ret += ifdef_pop(0) ret += "};\n" @@ -707,7 +731,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) # bitmasks ################################################################# - ret += f""" + ret += """ /* bitmasks *******************************************************************/ """ for typ in typs: @@ -715,7 +739,7 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ continue ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" + ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += ifdef_push(2, c_ver_ifdef({ver})) @@ -769,28 +793,32 @@ LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _val LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ + + def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: + return bool( + member.max or member.val or any(m.cnt == member for m in typ.members) + ) + for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" + ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: case idl.Number(): - ret += f"\treturn validate_{typ.prim.name}(ctx);\n" + ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" - ret += ( - f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" - ) + ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += f"\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' + ret += "\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case idl.Struct(): # and idl.Message() if len(typ.members) == 0: @@ -798,60 +826,51 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val ret += "}\n" continue - def should_save_value(member: idl.StructMember) -> bool: - nonlocal typ - assert isinstance(typ, idl.Struct) - return bool( - member.max - or member.val - or any(m.cnt == member for m in typ.members) - ) - # Pass 1 - declare value variables for member in typ.members: - if should_save_value(member): + if should_save_value(typ, member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ)} {member.name};\n" + ret += f"\t{c_typename(member.typ)} {member.membname};\n" ret += ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): - if tok.name[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" - mark_offset.add(tok.name[1:]) + if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): + if tok.symname[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" + mark_offset.add(tok.symname[1:]) # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || " + ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) - ret += f"_validate_size_net(ctx, {member.cnt.name})" + ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: - ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" - if typ.name == "s": # SPECIAL (string) - ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))" + if typ.typname == "s": # SPECIAL (string) + ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: - if should_save_value(member): + if should_save_value(typ, member): ret += "(" - if member.name in mark_offset: - ret += f"({{ _{member.name}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.name}(ctx)" - if member.name in mark_offset: + if member.membname in mark_offset: + ret += f"({{ _{member.membname}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.typname}(ctx)" + if member.membname in mark_offset: ret += "; })" - if should_save_value(member): + if should_save_value(typ, member): nbytes = member.static_size assert nbytes if nbytes == 1: - ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" else: - ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" + ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" @@ -871,14 +890,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.name}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n' + ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.name}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n' + ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" @@ -914,12 +933,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" match typ: case idl.Number(): - ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" @@ -939,21 +958,17 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "{\n" ret += prefix if member.typ.static_size == 1: # SPECIAL (string, zerocopy) - ret += f"out->{member.name} = (char *)&ctx->net_bytes[ctx->net_offset];\n" - ret += ( - f"{prefix}ctx->net_offset += out->{member.cnt.name};\n" - ) + ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" else: - ret += f"out->{member.name} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" + ret += f"out->{member.membname} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" else: - ret += ( - f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" - ) + ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" ret += ifdef_pop(1) ret += "}\n" ret += ifdef_pop(0) @@ -1140,7 +1155,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: def ret(path: Path) -> WalkCmd: - if len(path.elems) == 1 and path.elems[0].name == name: + if len(path.elems) == 1 and path.elems[0].membname == name: return WalkCmd.ABORT return WalkCmd.KEEP_GOING @@ -1148,13 +1163,13 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o for typ in typs: if not ( - isinstance(typ, idl.Message) or typ.name == "stat" + isinstance(typ, idl.Message) or typ.typname == "stat" ): # SPECIAL (include stat) continue assert isinstance(typ, idl.Struct) ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.name}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) @@ -1170,8 +1185,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ) ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" if isinstance(typ, idl.Message): # SPECIAL (disable for stat) - ret += f'\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' - ret += f'\t\t\t"{typ.name}",\n' + ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' + ret += f'\t\t\t"{typ.typname}",\n' ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' ret += "\t\t\tctx->ctx->max_msg_size);\n" ret += "\t\treturn true;\n" @@ -1209,12 +1224,12 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o for tok in member.val.tokens: if not isinstance(tok, idl.ExprSym): continue - if tok.name == "end" or tok.name.startswith("&"): - if tok.name not in offsets: - offsets.append(tok.name) + if tok.symname == "end" or tok.symname.startswith("&"): + if tok.symname not in offsets: + offsets.append(tok.symname) for name in offsets: name_prefix = "offsetof_" + "".join( - m.name + "_" for m in path.elems + m.membname + "_" for m in path.elems ) if name == "end": if not path.elems: @@ -1251,7 +1266,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o if child.cnt: cnt_path = path.parent().add(child.cnt) if child.typ.static_size == 1: # SPECIAL (zerocopy) - if path.root.name == "stat": # SPECIAL (stat) + if path.root.typname == "stat": # SPECIAL (stat) ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" @@ -1268,7 +1283,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o sym = sym[1:] return ( "offsetof_" - + "".join(m.name + "_" for m in path.elems[:-1]) + + "".join(m.membname + "_" for m in path.elems[:-1]) + sym ) @@ -1276,11 +1291,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): - val += f" & {child.typ.name}_masks[ctx->ctx->version]" + val += f" & {child.typ.typname}_masks[ctx->ctx->version]" ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" return WalkCmd.KEEP_GOING, pop walk(typ, handle) + del handle + del stack + del max_size ret += "\treturn false;\n" ret += "}\n" @@ -1293,7 +1311,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "\n" ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" - rerror = next(typ for typ in typs if typ.name == "Rerror") + rerror = next(typ for typ in typs if typ.typname == "Rerror") ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) for ver in sorted(versions): ret += ifdef_push(1, c_ver_ifdef({ver})) @@ -1342,9 +1360,7 @@ LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idpref # Main ######################################################################### -if __name__ == "__main__": - import sys - +def main() -> None: if typing.TYPE_CHECKING: class ANSIColors: @@ -1375,7 +1391,13 @@ if __name__ == "__main__": sys.exit(2) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) - with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: + with open( + os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8" + ) as fh: fh.write(gen_h(versions, typs)) - with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: + with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: fh.write(gen_c(versions, typs)) + + +if __name__ == "__main__": + main() diff --git a/lib9p/idl/__init__.py b/lib9p/idl/__init__.py index a01c38f..042a438 100644 --- a/lib9p/idl/__init__.py +++ b/lib9p/idl/__init__.py @@ -6,8 +6,9 @@ import enum import os.path import re -from typing import Callable, Literal, TypeVar, cast +import typing +# pylint: disable=unused-variable __all__ = [ # entrypoint "Parser", @@ -36,7 +37,7 @@ class Primitive(enum.Enum): return set() @property - def name(self) -> str: + def typname(self) -> str: return str(self.value) @property @@ -51,7 +52,7 @@ class Primitive(enum.Enum): class Number: - name: str + typname: str in_versions: set[str] prim: Primitive @@ -74,7 +75,7 @@ class Number: class BitfieldVal: - name: str + bitname: str in_versions: set[str] val: str @@ -84,7 +85,7 @@ class BitfieldVal: class Bitfield: - name: str + typname: str in_versions: set[str] prim: Primitive @@ -130,16 +131,16 @@ class ExprLit: class ExprSym: - name: str + symname: str def __init__(self, name: str) -> None: - self.name = name + self.symname = name class ExprOp: - op: Literal["-", "+"] + op: typing.Literal["-", "+"] - def __init__(self, op: Literal["-", "+"]) -> None: + def __init__(self, op: typing.Literal["-", "+"]) -> None: self.op = op @@ -156,7 +157,7 @@ class Expr: class StructMember: # from left-to-right when parsing cnt: "StructMember | None" = None - name: str + membname: str typ: "Type" max: Expr val: Expr @@ -168,10 +169,12 @@ class StructMember: assert self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( - f"list count must be an integer type: {repr(self.cnt.name)}" + f"list count must be an integer type: {repr(self.cnt.membname)}" ) if self.cnt.val: # TODO: allow this? - raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}") + raise ValueError( + f"list count may not have ,val=: {repr(self.cnt.membname)}" + ) return 0 @property @@ -179,26 +182,28 @@ class StructMember: assert self.cnt if not isinstance(self.cnt.typ, Primitive): raise ValueError( - f"list count must be an integer type: {repr(self.cnt.name)}" + f"list count must be an integer type: {repr(self.cnt.membname)}" ) if self.cnt.val: # TODO: allow this? - raise ValueError(f"list count may not have ,val=: {repr(self.cnt.name)}") + raise ValueError( + f"list count may not have ,val=: {repr(self.cnt.membname)}" + ) if self.cnt.max: # TODO: be more flexible? if len(self.cnt.max.tokens) != 1: raise ValueError( - f"list count ,max= may only have 1 token: {repr(self.cnt.name)}" + f"list count ,max= may only have 1 token: {repr(self.cnt.membname)}" ) match tok := self.cnt.max.tokens[0]: case ExprLit(): return tok.val - case ExprSym(name="s32_max"): + case ExprSym(symname="s32_max"): return (1 << 31) - 1 - case ExprSym(name="s64_max"): + case ExprSym(symname="s64_max"): return (1 << 63) - 1 case _: raise ValueError( - f'list count ,max= only allows literal, "s32_max", and "s64_max" tokens: {repr(self.cnt.name)}' + f'list count ,max= only allows literal, "s32_max", and "s64_max" tokens: {repr(self.cnt.membname)}' ) return (1 << (self.cnt.typ.value * 8)) - 1 @@ -218,7 +223,7 @@ class StructMember: class Struct: - name: str + typname: str in_versions: set[str] members: list[StructMember] @@ -257,7 +262,7 @@ class Message(Struct): @property def msgid(self) -> int: assert len(self.members) >= 3 - assert self.members[1].name == "typ" + assert self.members[1].membname == "typ" assert self.members[1].static_size == 1 assert self.members[1].val assert len(self.members[1].val.tokens) == 1 @@ -266,7 +271,7 @@ class Message(Struct): type Type = Primitive | Number | Bitfield | Struct | Message -T = TypeVar("T", Number, Bitfield, Struct, Message) +T = typing.TypeVar("T", Number, Bitfield, Struct, Message) # Parse ######################################################################## @@ -294,7 +299,7 @@ def parse_numspec(ver: str, n: Number, spec: str) -> None: name = m.group("name") val = m.group("val") if name in n.vals: - raise ValueError(f"{n.name}: name {repr(name)} already assigned") + raise ValueError(f"{n.typname}: name {repr(name)} already assigned") n.vals[name] = val else: raise SyntaxError(f"invalid num spec {repr(spec)}") @@ -310,39 +315,39 @@ def parse_bitspec(ver: str, bf: Bitfield, spec: str) -> None: name = m.group("name") val = BitfieldVal() - val.name = name + val.bitname = name val.val = f"1<<{bit}" val.in_versions.add(ver) if bit < 0 or bit >= len(bf.bits): - raise ValueError(f"{bf.name}: bit {bit} is out-of-bounds") + raise ValueError(f"{bf.typname}: bit {bit} is out-of-bounds") if bf.bits[bit]: - raise ValueError(f"{bf.name}: bit {bit} already assigned") - bf.bits[bit] = val.name + raise ValueError(f"{bf.typname}: bit {bit} already assigned") + bf.bits[bit] = val.bitname elif m := re.fullmatch(re_bitspec_alias, spec): name = m.group("name") valstr = m.group("val") val = BitfieldVal() - val.name = name + val.bitname = name val.val = valstr val.in_versions.add(ver) else: raise SyntaxError(f"invalid bitfield spec {repr(spec)}") - if val.name in bf.names: - raise ValueError(f"{bf.name}: name {val.name} already assigned") - bf.names[val.name] = val + if val.bitname in bf.names: + raise ValueError(f"{bf.typname}: name {val.bitname} already assigned") + bf.names[val.bitname] = val def parse_expr(expr: str) -> Expr: assert re.fullmatch(re_expr, expr) ret = Expr() for tok in re.split("([-+])", expr): - if tok == "-" or tok == "+": + if tok in ("-", "+"): # I, for the life of me, do not understand why I need this - # cast() to keep mypy happy. - ret.tokens += [ExprOp(cast(Literal["-", "+"], tok))] + # typing.cast() to keep mypy happy. + ret.tokens += [ExprOp(typing.cast(typing.Literal["-", "+"], tok))] elif re.fullmatch("[0-9]+", tok): ret.tokens += [ExprLit(int(tok))] else: @@ -359,16 +364,16 @@ def parse_members(ver: str, env: dict[str, Type], struct: Struct, specs: str) -> member = StructMember() member.in_versions = {ver} - member.name = m.group("name") - if any(x.name == member.name for x in struct.members): - raise ValueError(f"duplicate member name {repr(member.name)}") + member.membname = m.group("name") + if any(x.membname == member.membname for x in struct.members): + raise ValueError(f"duplicate member name {member.membname!r}") if m.group("typ") not in env: raise NameError(f"Unknown type {repr(m.group('typ'))}") member.typ = env[m.group("typ")] if cnt := m.group("cnt"): - if len(struct.members) == 0 or struct.members[-1].name != cnt: + if len(struct.members) == 0 or struct.members[-1].membname != cnt: raise ValueError(f"list count must be previous item: {repr(cnt)}") cnt_mem = struct.members[-1] member.cnt = cnt_mem @@ -412,7 +417,7 @@ re_line_cont = f"\\s+{re_string('specs')}" # could be bitfield/struct/msg def parse_file( - filename: str, get_include: Callable[[str], tuple[str, list[Type]]] + filename: str, get_include: typing.Callable[[str], tuple[str, list[Type]]] ) -> tuple[str, list[Type]]: version: str | None = None env: dict[str, Type] = { @@ -428,10 +433,10 @@ def parse_file( raise NameError(f"Unknown type {repr(name)}") ret = env[name] if (not isinstance(ret, tc)) or (ret.__class__.__name__ != tc.__name__): - raise NameError(f"Type {repr(ret.name)} is not a {tc.__name__}") + raise NameError(f"Type {repr(ret.typname)} is not a {tc.__name__}") return ret - with open(filename, "r") as fh: + with open(filename, "r", encoding="utf-8") as fh: prev: Type | None = None for lineno, line in enumerate(fh): try: @@ -452,7 +457,7 @@ def parse_file( symname = symname.strip() found = False for typ in other_typs: - if typ.name == symname or symname == "*": + if symname in (typ.typname, "*"): found = True match typ: case Primitive(): @@ -469,31 +474,31 @@ def parse_file( for member in typ.members: if other_version in member.in_versions: member.in_versions.add(version) - if typ.name in env and env[typ.name] != typ: + if typ.typname in env and env[typ.typname] != typ: raise ValueError( - f"duplicate type name {repr(typ.name)}" + f"duplicate type name {typ.typname!r}" ) - env[typ.name] = typ + env[typ.typname] = typ if symname != "*" and not found: raise ValueError( f"import: {m.group('file')}: no symbol {repr(symname)}" ) elif m := re.fullmatch(re_line_num, line): num = Number() - num.name = m.group("name") + num.typname = m.group("name") num.in_versions.add(version) prim = env[m.group("prim")] assert isinstance(prim, Primitive) num.prim = prim - if num.name in env: - raise ValueError(f"duplicate type name {repr(num.name)}") - env[num.name] = num + if num.typname in env: + raise ValueError(f"duplicate type name {num.typname!r}") + env[num.typname] = num prev = num elif m := re.fullmatch(re_line_bitfield, line): bf = Bitfield() - bf.name = m.group("name") + bf.typname = m.group("name") bf.in_versions.add(version) prim = env[m.group("prim")] @@ -502,9 +507,9 @@ def parse_file( bf.bits = (prim.static_size * 8) * [""] - if bf.name in env: - raise ValueError(f"duplicate type name {repr(bf.name)}") - env[bf.name] = bf + if bf.typname in env: + raise ValueError(f"duplicate type name {bf.typname!r}") + env[bf.typname] = bf prev = bf elif m := re.fullmatch(re_line_bitfield_, line): bf = get_type(m.group("name"), Bitfield) @@ -515,16 +520,16 @@ def parse_file( match m.group("op"): case "=": struct = Struct() - struct.name = m.group("name") + struct.typname = m.group("name") struct.in_versions.add(version) struct.members = [] parse_members(version, env, struct, m.group("members")) - if struct.name in env: + if struct.typname in env: raise ValueError( - f"duplicate type name {repr(struct.name)}" + f"duplicate type name {struct.typname!r}" ) - env[struct.name] = struct + env[struct.typname] = struct prev = struct case "+=": struct = get_type(m.group("name"), Struct) @@ -535,16 +540,14 @@ def parse_file( match m.group("op"): case "=": msg = Message() - msg.name = m.group("name") + msg.typname = m.group("name") msg.in_versions.add(version) msg.members = [] parse_members(version, env, msg, m.group("members")) - if msg.name in env: - raise ValueError( - f"duplicate type name {repr(msg.name)}" - ) - env[msg.name] = msg + if msg.typname in env: + raise ValueError(f"duplicate type name {msg.typname!r}") + env[msg.typname] = msg prev = msg case "+=": msg = get_type(m.group("name"), Message) @@ -577,19 +580,24 @@ def parse_file( typs: list[Type] = [x for x in env.values() if not isinstance(x, Primitive)] for typ in [typ for typ in typs if isinstance(typ, Struct)]: - valid_syms = ["end", "s32_max", "s64_max", *["&" + m.name for m in typ.members]] + valid_syms = [ + "end", + "s32_max", + "s64_max", + *["&" + m.membname for m in typ.members], + ] for member in typ.members: if ( not isinstance(member.typ, Primitive) and member.typ.in_versions < member.in_versions ): raise ValueError( - f"{typ.name}.{member.name}: type {member.typ.name} does not exist in {member.in_versions.difference(member.typ.in_versions)}" + f"{typ.typname}.{member.membname}: type {member.typ.typname} does not exist in {member.in_versions.difference(member.typ.in_versions)}" ) for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, ExprSym) and tok.name not in valid_syms: + if isinstance(tok, ExprSym) and tok.symname not in valid_syms: raise ValueError( - f"{typ.name}.{member.name}: invalid sym: {tok.name}" + f"{typ.typname}.{member.membname}: invalid sym: {tok.symname}" ) return version, typs @@ -619,11 +627,11 @@ class Parser: raise ValueError(f"duplicate protocol version {repr(version)}") ret_versions.add(version) for typ in typs: - if typ.name in ret_typs: - if typ != ret_typs[typ.name]: - raise ValueError(f"duplicate type name {repr(typ.name)}") + if typ.typname in ret_typs: + if typ != ret_typs[typ.typname]: + raise ValueError(f"duplicate type name {repr(typ.typname)}") else: - ret_typs[typ.name] = typ + ret_typs[typ.typname] = typ msgids: set[int] = set() for typ in ret_typs.values(): if isinstance(typ, Message): diff --git a/lib9p/include/lib9p/linux-errno.h.gen b/lib9p/include/lib9p/linux-errno.h.gen index 8f4e0c8..2c736a2 100755 --- a/lib9p/include/lib9p/linux-errno.h.gen +++ b/lib9p/include/lib9p/linux-errno.h.gen @@ -1,7 +1,7 @@ #!/usr/bin/env python # lib9p/linux-errno.h.gen - Generate a C header from a list of errno numbers # -# Copyright (C) 2024 Luke T. Shumaker <lukeshu@lukeshu.com> +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> # SPDX-License-Identifier: AGPL-3.0-or-later import sys @@ -13,7 +13,7 @@ def print_errnos() -> None: ) errnos: dict[str, tuple[int, str]] = {} for txtlist in sys.argv[1:]: - with open(txtlist, "r") as fh: + with open(txtlist, "r", encoding="utf-8") as fh: for line in fh: if line.startswith("#"): print(f"/* {line[1:].strip()} */") @@ -26,12 +26,10 @@ def print_errnos() -> None: print("#ifndef _LIB9P_LINUX_ERRNO_H_") print("#define _LIB9P_LINUX_ERRNO_H_") print() - namelen = max(len(name) for name in errnos.keys()) + namelen = max(len(name) for name in errnos) numlen = max(len(str(num)) for (num, desc) in errnos.values()) - for name in errnos: - print( - f"#define LINUX_{name.ljust(namelen)} {str(errnos[name][0]).rjust(numlen)} /* {errnos[name][1]} */" - ) + for name, [num, msg] in errnos.items(): + print(f"#define LINUX_{name:<{namelen}} {num:>{numlen}} /* {msg} */") print() print("#endif /* _LIB9P_LINUX_ERRNO_H_ */") |