diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-30 02:02:45 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-31 03:51:06 -0600 |
commit | 8216970914131202c0f81a5d83d57a779ed8ca0f (patch) | |
tree | 1f21615a9f8eea2803651cfa9edd7095b704f7ae | |
parent | 32e35da60395d1279b8732df998629af79c591b5 (diff) |
measurestack: Try to reduce how often we call str()
-rw-r--r-- | build-aux/measurestack/__init__.py | 174 |
1 files changed, 93 insertions, 81 deletions
diff --git a/build-aux/measurestack/__init__.py b/build-aux/measurestack/__init__.py index b934f40..3f59bd6 100644 --- a/build-aux/measurestack/__init__.py +++ b/build-aux/measurestack/__init__.py @@ -165,6 +165,9 @@ class BaseName: def __hash__(self) -> int: return hash(self._content) + def as_qname(self) -> "QName": + return QName(self._content) + class QName: _content: str @@ -194,7 +197,7 @@ class QName: return hash(self._content) def base(self) -> BaseName: - return BaseName(str(self).rsplit(":", 1)[-1].split(".", 1)[0]) + return BaseName(self._content.rsplit(":", 1)[-1].split(".", 1)[0]) class Node: @@ -327,7 +330,7 @@ def analyze( raise ValueError(f"incomplete edge: {elem.attrs!r}") if caller not in graph: raise ValueError(f"unknown caller: {caller}") - if str(callee) == "__indirect_call": + if callee == QName("__indirect_call"): callees, missing_ok = app.indirect_callees(elem) for callee in callees: if callee not in graph[caller].calls: @@ -364,15 +367,14 @@ def analyze( funcname = QName(str(funcname)[len("__real_") :]) # Usual case - if QName(str(funcname)) in graph: - return QName(str(funcname)) + if funcname in graph: + return funcname # Handle `__weak` functions - if ( - ":" not in str(funcname) - and len(qualified.get(BaseName(str(funcname)), set())) == 1 - ): - return sorted(qualified[BaseName(str(funcname))])[0] + if ":" not in str(funcname): + qnames = qualified.get(BaseName(str(funcname)), set()) + if len(qnames) == 1: + return sorted(qnames)[0] return None @@ -387,7 +389,7 @@ def analyze( nonlocal track_inclusion funcname = resolve_funcname(orig_funcname) if not funcname: - if chain and app.skip_call(chain, QName(str(orig_funcname))): + if chain and app.skip_call(chain, orig_funcname): if dbg: print(f"//dbg: {'- '*len(chain)}{orig_funcname}\tskip missing") return 0 @@ -476,7 +478,7 @@ class Plugin(typing.Protocol): # called, but are included in the binary anyway. This may because # it is an unused method in a used vtable. This may be because it # is an atexit() callback (we never exit). - def extra_includes(self) -> typing.Collection[str]: ... + def extra_includes(self) -> typing.Collection[BaseName]: ... def extra_nodes(self) -> typing.Collection[Node]: ... def indirect_callees( @@ -534,7 +536,7 @@ class CmdPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -607,7 +609,7 @@ class LibObjPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -641,18 +643,18 @@ class LibHWPlugin: self.libobj = libobj def is_intrhandler(self, name: QName) -> bool: - return str(name.base()) in [ - "rp2040_hwtimer_intrhandler", - "hostclock_handle_sig_alarm", - "hostnet_handle_sig_io", - "gpioirq_handler", - "dmairq_handler", + return name.base() in [ + BaseName("rp2040_hwtimer_intrhandler"), + BaseName("hostclock_handle_sig_alarm"), + BaseName("hostnet_handle_sig_io"), + BaseName("gpioirq_handler"), + BaseName("dmairq_handler"), ] def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -663,14 +665,14 @@ class LibHWPlugin: ) -> tuple[typing.Collection[QName], bool] | None: if "/3rd-party/" in loc: return None - for fn in ( + for fn in [ "io_readv", "io_writev", "io_close", "io_close_read", "io_close_write", "io_readwritev", - ): + ]: if f"{fn}(" in line: return self.libobj.indirect_callees(loc, f"LO_CALL(x, {fn[3:]})") if "io_read(" in line: @@ -703,12 +705,14 @@ class LibHWPlugin: class LibCRPlugin: def is_intrhandler(self, name: QName) -> bool: - return str(name.base()) in ("_cr_gdb_intrhandler",) + return name.base() in [ + BaseName("_cr_gdb_intrhandler"), + ] def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -730,7 +734,7 @@ class LibCRIPCPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -838,7 +842,7 @@ class Lib9PPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -891,7 +895,7 @@ class LibMiscPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -905,29 +909,29 @@ class LibMiscPlugin: def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: if ( len(chain) > 1 - and str(chain[-1].base()) == "__assert_msg_fail" - and str(call.base()) == "__lm_printf" - and any(str(c.base()) == "__assert_msg_fail" for c in chain[:-1]) + and chain[-1].base() == BaseName("__assert_msg_fail") + and call.base() == BaseName("__lm_printf") + and any(c.base() == BaseName("__assert_msg_fail") for c in chain[:-1]) ): return True return False class PicoFmtPlugin: - known_out: dict[str, str] - known_fct: dict[str, str] + known_out: dict[BaseName, BaseName] + known_fct: dict[BaseName, BaseName] def __init__(self) -> None: self.known_out = { - "": "_out_null", # XXX - "__wrap_sprintf": "_out_buffer", - "__wrap_snprintf": "_out_buffer", - "__wrap_vsnprintf": "_out_buffer", - "vfctprintf": "_out_fct", + BaseName(""): BaseName("_out_null"), # XXX + BaseName("__wrap_sprintf"): BaseName("_out_buffer"), + BaseName("__wrap_snprintf"): BaseName("_out_buffer"), + BaseName("__wrap_vsnprintf"): BaseName("_out_buffer"), + BaseName("vfctprintf"): BaseName("_out_fct"), } self.known_fct = { - "stdio_vprintf": "stdio_buffered_printer", - "__wrap_vprintf": "stdio_buffered_printer", + BaseName("stdio_vprintf"): BaseName("stdio_buffered_printer"), + BaseName("__wrap_vprintf"): BaseName("stdio_buffered_printer"), } def is_intrhandler(self, name: QName) -> bool: @@ -936,7 +940,7 @@ class PicoFmtPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -951,28 +955,28 @@ class PicoFmtPlugin: m = re_call_other.fullmatch(line) call: str | None = m.group("func") if m else None if call == "out": - return [QName(x) for x in self.known_out.values()], False + return [x.as_qname() for x in self.known_out.values()], False if "->fct" in line: - return [QName(x) for x in self.known_fct.values()], False + return [x.as_qname() for x in self.known_fct.values()], False return None def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool: - if str(call.base()) in self.known_out.values(): - out = "" + if call.base() in self.known_out.values(): + out: BaseName | None = None for pcall in chain: - if str(pcall.base()) in self.known_out: - out = self.known_out[str(pcall.base())] - if ( - out == "_out_buffer" and str(call.base()) == "_out_null" + if pcall.base() in self.known_out: + out = self.known_out[pcall.base()] + if out == BaseName("_out_buffer") and call.base() == BaseName( + "_out_null" ): # XXX: Gross hack - out = "_out_null" - return str(call.base()) != out - if str(call.base()) in self.known_fct.values(): - fct = "" + out = BaseName("_out_null") + return call.base() != out + if call.base() in self.known_fct.values(): + fct: BaseName | None = None for pcall in chain: - if str(pcall.base()) in self.known_fct: - fct = self.known_fct[str(pcall.base())] - return str(call.base()) != fct + if pcall.base() in self.known_fct: + fct = self.known_fct[pcall.base()] + return call.base() != fct return False @@ -1019,20 +1023,20 @@ class PicoSDKPlugin: ] def is_intrhandler(self, name: QName) -> bool: - return str(name.base()) in [ - "isr_invalid", - "isr_nmi", - "isr_hardfault", - "isr_svcall", - "isr_pendsv", - "isr_systick", - *[f"isr_irq{n}" for n in range(32)], + return name.base() in [ + BaseName("isr_invalid"), + BaseName("isr_nmi"), + BaseName("isr_hardfault"), + BaseName("isr_svcall"), + BaseName("isr_pendsv"), + BaseName("isr_systick"), + *[BaseName(f"isr_irq{n}") for n in range(32)], ] def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def indirect_callees( @@ -1285,7 +1289,7 @@ class TinyUSBDevicePlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -1324,10 +1328,10 @@ class NewlibPlugin: def init_array(self) -> typing.Collection[QName]: return [QName("register_fini")] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [ # register_fini() calls atexit(__libc_fini_array) - "__libc_fini_array", + BaseName("__libc_fini_array"), ] def extra_nodes(self) -> typing.Collection[Node]: @@ -1391,7 +1395,7 @@ class LibGCCPlugin: def init_array(self) -> typing.Collection[QName]: return [] - def extra_includes(self) -> typing.Collection[str]: + def extra_includes(self) -> typing.Collection[BaseName]: return [] def extra_nodes(self) -> typing.Collection[Node]: @@ -1429,14 +1433,16 @@ def main( lib9p_plugin = Lib9PPlugin(arg_base_dir, arg_c_fnames) def sbc_is_thread(name: QName) -> int: - if str(name).endswith("_cr") and str(name.base()) != "lib9p_srv_read_cr": + if str(name).endswith("_cr") and name.base() != BaseName("lib9p_srv_read_cr"): if "9p" in str(name.base()) or "lib9p/tests/test_server/main.c:" in str( name ): return lib9p_plugin.thread_count(name) return 1 - if str(name.base()) == ( - "_entry_point" if arg_pico_platform == "rp2040" else "main" + if name.base() == ( + BaseName("_entry_point") + if arg_pico_platform == "rp2040" + else BaseName("main") ): return 1 return 0 @@ -1485,27 +1491,33 @@ def main( return 0, False def misc_filter(name: QName) -> tuple[int, bool]: - if str(name.base()) in ["__lm_printf", "__assert_msg_fail"]: + if name.base() in [ + BaseName("__lm_printf"), + BaseName("__assert_msg_fail"), + ]: return 1, False return 0, False - extra_includes: list[str] = [] + extra_includes: list[BaseName] = [] for plugin in plugins: extra_includes.extend(plugin.extra_includes()) def extra_filter(name: QName) -> tuple[int, bool]: nonlocal extra_includes - if str(name.base()) in extra_includes: + if name.base() in extra_includes: return 1, True return 0, False - def location_xform(loc: str) -> str: + def _str_location_xform(loc: str) -> str: if not loc.startswith("/"): return loc parts = loc.split(":", 1) parts[0] = "./" + os.path.relpath(parts[0], arg_base_dir) return ":".join(parts) + def location_xform(_loc: QName) -> str: + return _str_location_xform(str(_loc)) + result = analyze( ci_fnames=arg_ci_fnames, app_func_filters={ @@ -1514,7 +1526,7 @@ def main( "Misc": misc_filter, "Extra": extra_filter, }, - app=PluginApplication(location_xform, plugins), + app=PluginApplication(_str_location_xform, plugins), cfg_max_call_depth=100, ) @@ -1526,7 +1538,7 @@ def main( # Figure sizes. namelen = max( - [len(location_xform(str(k))) for k in grp.rows.keys()] + [len(grp_name) + 4] + [len(location_xform(k)) for k in grp.rows.keys()] + [len(grp_name) + 4] ) numlen = len(str(nsum)) sep1 = ("=" * namelen) + " " + "=" * numlen @@ -1535,7 +1547,7 @@ def main( # Print. print("= " + grp_name + " " + sep1[len(grp_name) + 3 :]) for qname, val in sorted(grp.rows.items()): - name = location_xform(str(qname)) + name = location_xform(qname) if val.nstatic == 0: continue print( @@ -1573,7 +1585,7 @@ def main( name = str(funcname.base()) base = val.nstatic size = base + intrstack - if name in ("main", "_entry_point"): + if name in ["main", "_entry_point"]: mainrow = CrRow(name=name, cnt=1, base=base, size=size) else: size = next_power_of_2(size + stack_guard_size) - stack_guard_size @@ -1618,9 +1630,9 @@ def main( print_group("Misc") for funcname in sorted(result.missing): - print(f"warning: missing: {location_xform(str(funcname))}") + print(f"warning: missing: {location_xform(funcname)}") for funcname in sorted(result.dynamic): - print(f"warning: dynamic-stack-usage: {location_xform(str(funcname))}") + print(f"warning: dynamic-stack-usage: {location_xform(funcname)}") print("*/") print("") |