summaryrefslogtreecommitdiff
path: root/build-aux/measurestack/__init__.py
diff options
context:
space:
mode:
authorLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-30 02:02:45 -0600
committerLuke T. Shumaker <lukeshu@lukeshu.com>2025-03-31 03:51:06 -0600
commit8216970914131202c0f81a5d83d57a779ed8ca0f (patch)
tree1f21615a9f8eea2803651cfa9edd7095b704f7ae /build-aux/measurestack/__init__.py
parent32e35da60395d1279b8732df998629af79c591b5 (diff)
measurestack: Try to reduce how often we call str()
Diffstat (limited to 'build-aux/measurestack/__init__.py')
-rw-r--r--build-aux/measurestack/__init__.py174
1 files changed, 93 insertions, 81 deletions
diff --git a/build-aux/measurestack/__init__.py b/build-aux/measurestack/__init__.py
index b934f40..3f59bd6 100644
--- a/build-aux/measurestack/__init__.py
+++ b/build-aux/measurestack/__init__.py
@@ -165,6 +165,9 @@ class BaseName:
def __hash__(self) -> int:
return hash(self._content)
+ def as_qname(self) -> "QName":
+ return QName(self._content)
+
class QName:
_content: str
@@ -194,7 +197,7 @@ class QName:
return hash(self._content)
def base(self) -> BaseName:
- return BaseName(str(self).rsplit(":", 1)[-1].split(".", 1)[0])
+ return BaseName(self._content.rsplit(":", 1)[-1].split(".", 1)[0])
class Node:
@@ -327,7 +330,7 @@ def analyze(
raise ValueError(f"incomplete edge: {elem.attrs!r}")
if caller not in graph:
raise ValueError(f"unknown caller: {caller}")
- if str(callee) == "__indirect_call":
+ if callee == QName("__indirect_call"):
callees, missing_ok = app.indirect_callees(elem)
for callee in callees:
if callee not in graph[caller].calls:
@@ -364,15 +367,14 @@ def analyze(
funcname = QName(str(funcname)[len("__real_") :])
# Usual case
- if QName(str(funcname)) in graph:
- return QName(str(funcname))
+ if funcname in graph:
+ return funcname
# Handle `__weak` functions
- if (
- ":" not in str(funcname)
- and len(qualified.get(BaseName(str(funcname)), set())) == 1
- ):
- return sorted(qualified[BaseName(str(funcname))])[0]
+ if ":" not in str(funcname):
+ qnames = qualified.get(BaseName(str(funcname)), set())
+ if len(qnames) == 1:
+ return sorted(qnames)[0]
return None
@@ -387,7 +389,7 @@ def analyze(
nonlocal track_inclusion
funcname = resolve_funcname(orig_funcname)
if not funcname:
- if chain and app.skip_call(chain, QName(str(orig_funcname))):
+ if chain and app.skip_call(chain, orig_funcname):
if dbg:
print(f"//dbg: {'- '*len(chain)}{orig_funcname}\tskip missing")
return 0
@@ -476,7 +478,7 @@ class Plugin(typing.Protocol):
# called, but are included in the binary anyway. This may because
# it is an unused method in a used vtable. This may be because it
# is an atexit() callback (we never exit).
- def extra_includes(self) -> typing.Collection[str]: ...
+ def extra_includes(self) -> typing.Collection[BaseName]: ...
def extra_nodes(self) -> typing.Collection[Node]: ...
def indirect_callees(
@@ -534,7 +536,7 @@ class CmdPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -607,7 +609,7 @@ class LibObjPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -641,18 +643,18 @@ class LibHWPlugin:
self.libobj = libobj
def is_intrhandler(self, name: QName) -> bool:
- return str(name.base()) in [
- "rp2040_hwtimer_intrhandler",
- "hostclock_handle_sig_alarm",
- "hostnet_handle_sig_io",
- "gpioirq_handler",
- "dmairq_handler",
+ return name.base() in [
+ BaseName("rp2040_hwtimer_intrhandler"),
+ BaseName("hostclock_handle_sig_alarm"),
+ BaseName("hostnet_handle_sig_io"),
+ BaseName("gpioirq_handler"),
+ BaseName("dmairq_handler"),
]
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -663,14 +665,14 @@ class LibHWPlugin:
) -> tuple[typing.Collection[QName], bool] | None:
if "/3rd-party/" in loc:
return None
- for fn in (
+ for fn in [
"io_readv",
"io_writev",
"io_close",
"io_close_read",
"io_close_write",
"io_readwritev",
- ):
+ ]:
if f"{fn}(" in line:
return self.libobj.indirect_callees(loc, f"LO_CALL(x, {fn[3:]})")
if "io_read(" in line:
@@ -703,12 +705,14 @@ class LibHWPlugin:
class LibCRPlugin:
def is_intrhandler(self, name: QName) -> bool:
- return str(name.base()) in ("_cr_gdb_intrhandler",)
+ return name.base() in [
+ BaseName("_cr_gdb_intrhandler"),
+ ]
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -730,7 +734,7 @@ class LibCRIPCPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -838,7 +842,7 @@ class Lib9PPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -891,7 +895,7 @@ class LibMiscPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -905,29 +909,29 @@ class LibMiscPlugin:
def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool:
if (
len(chain) > 1
- and str(chain[-1].base()) == "__assert_msg_fail"
- and str(call.base()) == "__lm_printf"
- and any(str(c.base()) == "__assert_msg_fail" for c in chain[:-1])
+ and chain[-1].base() == BaseName("__assert_msg_fail")
+ and call.base() == BaseName("__lm_printf")
+ and any(c.base() == BaseName("__assert_msg_fail") for c in chain[:-1])
):
return True
return False
class PicoFmtPlugin:
- known_out: dict[str, str]
- known_fct: dict[str, str]
+ known_out: dict[BaseName, BaseName]
+ known_fct: dict[BaseName, BaseName]
def __init__(self) -> None:
self.known_out = {
- "": "_out_null", # XXX
- "__wrap_sprintf": "_out_buffer",
- "__wrap_snprintf": "_out_buffer",
- "__wrap_vsnprintf": "_out_buffer",
- "vfctprintf": "_out_fct",
+ BaseName(""): BaseName("_out_null"), # XXX
+ BaseName("__wrap_sprintf"): BaseName("_out_buffer"),
+ BaseName("__wrap_snprintf"): BaseName("_out_buffer"),
+ BaseName("__wrap_vsnprintf"): BaseName("_out_buffer"),
+ BaseName("vfctprintf"): BaseName("_out_fct"),
}
self.known_fct = {
- "stdio_vprintf": "stdio_buffered_printer",
- "__wrap_vprintf": "stdio_buffered_printer",
+ BaseName("stdio_vprintf"): BaseName("stdio_buffered_printer"),
+ BaseName("__wrap_vprintf"): BaseName("stdio_buffered_printer"),
}
def is_intrhandler(self, name: QName) -> bool:
@@ -936,7 +940,7 @@ class PicoFmtPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -951,28 +955,28 @@ class PicoFmtPlugin:
m = re_call_other.fullmatch(line)
call: str | None = m.group("func") if m else None
if call == "out":
- return [QName(x) for x in self.known_out.values()], False
+ return [x.as_qname() for x in self.known_out.values()], False
if "->fct" in line:
- return [QName(x) for x in self.known_fct.values()], False
+ return [x.as_qname() for x in self.known_fct.values()], False
return None
def skip_call(self, chain: typing.Sequence[QName], call: QName) -> bool:
- if str(call.base()) in self.known_out.values():
- out = ""
+ if call.base() in self.known_out.values():
+ out: BaseName | None = None
for pcall in chain:
- if str(pcall.base()) in self.known_out:
- out = self.known_out[str(pcall.base())]
- if (
- out == "_out_buffer" and str(call.base()) == "_out_null"
+ if pcall.base() in self.known_out:
+ out = self.known_out[pcall.base()]
+ if out == BaseName("_out_buffer") and call.base() == BaseName(
+ "_out_null"
): # XXX: Gross hack
- out = "_out_null"
- return str(call.base()) != out
- if str(call.base()) in self.known_fct.values():
- fct = ""
+ out = BaseName("_out_null")
+ return call.base() != out
+ if call.base() in self.known_fct.values():
+ fct: BaseName | None = None
for pcall in chain:
- if str(pcall.base()) in self.known_fct:
- fct = self.known_fct[str(pcall.base())]
- return str(call.base()) != fct
+ if pcall.base() in self.known_fct:
+ fct = self.known_fct[pcall.base()]
+ return call.base() != fct
return False
@@ -1019,20 +1023,20 @@ class PicoSDKPlugin:
]
def is_intrhandler(self, name: QName) -> bool:
- return str(name.base()) in [
- "isr_invalid",
- "isr_nmi",
- "isr_hardfault",
- "isr_svcall",
- "isr_pendsv",
- "isr_systick",
- *[f"isr_irq{n}" for n in range(32)],
+ return name.base() in [
+ BaseName("isr_invalid"),
+ BaseName("isr_nmi"),
+ BaseName("isr_hardfault"),
+ BaseName("isr_svcall"),
+ BaseName("isr_pendsv"),
+ BaseName("isr_systick"),
+ *[BaseName(f"isr_irq{n}") for n in range(32)],
]
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def indirect_callees(
@@ -1285,7 +1289,7 @@ class TinyUSBDevicePlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -1324,10 +1328,10 @@ class NewlibPlugin:
def init_array(self) -> typing.Collection[QName]:
return [QName("register_fini")]
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return [
# register_fini() calls atexit(__libc_fini_array)
- "__libc_fini_array",
+ BaseName("__libc_fini_array"),
]
def extra_nodes(self) -> typing.Collection[Node]:
@@ -1391,7 +1395,7 @@ class LibGCCPlugin:
def init_array(self) -> typing.Collection[QName]:
return []
- def extra_includes(self) -> typing.Collection[str]:
+ def extra_includes(self) -> typing.Collection[BaseName]:
return []
def extra_nodes(self) -> typing.Collection[Node]:
@@ -1429,14 +1433,16 @@ def main(
lib9p_plugin = Lib9PPlugin(arg_base_dir, arg_c_fnames)
def sbc_is_thread(name: QName) -> int:
- if str(name).endswith("_cr") and str(name.base()) != "lib9p_srv_read_cr":
+ if str(name).endswith("_cr") and name.base() != BaseName("lib9p_srv_read_cr"):
if "9p" in str(name.base()) or "lib9p/tests/test_server/main.c:" in str(
name
):
return lib9p_plugin.thread_count(name)
return 1
- if str(name.base()) == (
- "_entry_point" if arg_pico_platform == "rp2040" else "main"
+ if name.base() == (
+ BaseName("_entry_point")
+ if arg_pico_platform == "rp2040"
+ else BaseName("main")
):
return 1
return 0
@@ -1485,27 +1491,33 @@ def main(
return 0, False
def misc_filter(name: QName) -> tuple[int, bool]:
- if str(name.base()) in ["__lm_printf", "__assert_msg_fail"]:
+ if name.base() in [
+ BaseName("__lm_printf"),
+ BaseName("__assert_msg_fail"),
+ ]:
return 1, False
return 0, False
- extra_includes: list[str] = []
+ extra_includes: list[BaseName] = []
for plugin in plugins:
extra_includes.extend(plugin.extra_includes())
def extra_filter(name: QName) -> tuple[int, bool]:
nonlocal extra_includes
- if str(name.base()) in extra_includes:
+ if name.base() in extra_includes:
return 1, True
return 0, False
- def location_xform(loc: str) -> str:
+ def _str_location_xform(loc: str) -> str:
if not loc.startswith("/"):
return loc
parts = loc.split(":", 1)
parts[0] = "./" + os.path.relpath(parts[0], arg_base_dir)
return ":".join(parts)
+ def location_xform(_loc: QName) -> str:
+ return _str_location_xform(str(_loc))
+
result = analyze(
ci_fnames=arg_ci_fnames,
app_func_filters={
@@ -1514,7 +1526,7 @@ def main(
"Misc": misc_filter,
"Extra": extra_filter,
},
- app=PluginApplication(location_xform, plugins),
+ app=PluginApplication(_str_location_xform, plugins),
cfg_max_call_depth=100,
)
@@ -1526,7 +1538,7 @@ def main(
# Figure sizes.
namelen = max(
- [len(location_xform(str(k))) for k in grp.rows.keys()] + [len(grp_name) + 4]
+ [len(location_xform(k)) for k in grp.rows.keys()] + [len(grp_name) + 4]
)
numlen = len(str(nsum))
sep1 = ("=" * namelen) + " " + "=" * numlen
@@ -1535,7 +1547,7 @@ def main(
# Print.
print("= " + grp_name + " " + sep1[len(grp_name) + 3 :])
for qname, val in sorted(grp.rows.items()):
- name = location_xform(str(qname))
+ name = location_xform(qname)
if val.nstatic == 0:
continue
print(
@@ -1573,7 +1585,7 @@ def main(
name = str(funcname.base())
base = val.nstatic
size = base + intrstack
- if name in ("main", "_entry_point"):
+ if name in ["main", "_entry_point"]:
mainrow = CrRow(name=name, cnt=1, base=base, size=size)
else:
size = next_power_of_2(size + stack_guard_size) - stack_guard_size
@@ -1618,9 +1630,9 @@ def main(
print_group("Misc")
for funcname in sorted(result.missing):
- print(f"warning: missing: {location_xform(str(funcname))}")
+ print(f"warning: missing: {location_xform(funcname)}")
for funcname in sorted(result.dynamic):
- print(f"warning: dynamic-stack-usage: {location_xform(str(funcname))}")
+ print(f"warning: dynamic-stack-usage: {location_xform(funcname)}")
print("*/")
print("")