diff options
Diffstat (limited to 'build-aux/measurestack/analyze.py')
-rw-r--r-- | build-aux/measurestack/analyze.py | 209 |
1 files changed, 92 insertions, 117 deletions
diff --git a/build-aux/measurestack/analyze.py b/build-aux/measurestack/analyze.py index 67c44ce..3dc1a26 100644 --- a/build-aux/measurestack/analyze.py +++ b/build-aux/measurestack/analyze.py @@ -27,7 +27,10 @@ __all__ = [ "QName", "UsageKind", "Node", + "NodeHandler", + "NodeHandleCB", "maybe_sorted", + "handle_simple_node", "AnalyzeResultVal", "AnalyzeResultGroup", "AnalyzeResult", @@ -229,41 +232,16 @@ class AnalyzeResult(typing.NamedTuple): included_funcs: set[QName] -class SkipModel(typing.NamedTuple): - """Running the skipmodel calls `.fn(chain, ...)` with the chain - consisting of the last few items of the input chain. - - If `.nchain` is an int: - - - the chain is the last `.nchain` items or the input chain. If - the input chain is not that long, then `.fn` is not called and - the call is *not* skipped. - - If `.nchain` is a collection: - - - the chain starts with the *last* occurance of `.nchain` in the - input chain. If the input chain does not contain a member of - the collection, then .fn is called with an empty chain. - """ +class NodeHandleCB(typing.Protocol): + def __call__( + self, chain: typing.Sequence[QName], missing_ok: bool = False + ) -> tuple[int, bool]: ... - nchain: int | typing.Collection[BaseName] - fn: typing.Callable[[typing.Sequence[QName], Node, QName], bool] +class NodeHandler(typing.Protocol): def __call__( - self, chain: typing.Sequence[QName], node: Node, call: QName - ) -> tuple[bool, int]: - match self.nchain: - case int(): - if len(chain) >= self.nchain: - _chain = chain[-self.nchain :] - return self.fn(_chain, node, call), len(_chain) + 1 - return False, 0 - case _: - for i in reversed(range(len(chain))): - if chain[i].base() in self.nchain: - _chain = chain[i:] - return self.fn(_chain, node, call), len(_chain) + 1 - return self.fn([], node, call), 1 + self, handle: NodeHandleCB, chain: typing.Sequence[QName], node: Node + ) -> tuple[int, bool]: ... class Application(typing.Protocol): @@ -271,7 +249,7 @@ class Application(typing.Protocol): def indirect_callees( self, elem: vcg.VCGElem ) -> tuple[typing.Collection[QName], bool]: ... - def skipmodels(self) -> dict[BaseName, SkipModel]: ... + def node_handlers(self) -> dict[BaseName, NodeHandler]: ... # code ######################################################################### @@ -450,6 +428,28 @@ def _make_graph( return ret +def handle_simple_node( + handle: NodeHandleCB, + chain: typing.Sequence[QName], + node: Node, + skip_p: typing.Callable[[QName], bool] | None = None, +) -> tuple[int, bool]: + cacheable = True + max_call_nstatic = 0 + for call_qname, call_missing_ok in maybe_sorted(node.calls.items()): + if skip_p and skip_p(call_qname): + if dbg_nstatic: + print(f"//dbg-nstatic: {'- '*(len(chain)+1)}{call_qname}\tskip") + continue + call_nstatic, call_cacheable = handle( + [*chain, node.funcname, call_qname], call_missing_ok + ) + max_call_nstatic = max(max_call_nstatic, call_nstatic) + if not call_cacheable: + cacheable = False + return node.nstatic + max_call_nstatic, cacheable + + def analyze( *, ci_fnames: typing.Collection[str], @@ -467,94 +467,69 @@ def analyze( track_inclusion: bool = True - skipmodels = app.skipmodels() - for name, model in skipmodels.items(): - if not isinstance(model.nchain, int): - assert len(model.nchain) > 0 - - _nstatic_cache: dict[QName, int] = {} - - def _nstatic(chain: list[QName], funcname: QName) -> tuple[int, int]: - nonlocal track_inclusion - - assert funcname in graphdata.graph - - def putdbg(msg: str) -> None: - print(f"//dbg-nstatic: {'- '*len(chain)}{msg}") - - node = graphdata.graph[funcname] - if dbg_nstatic: - putdbg(f"{funcname}\t{node.nstatic}") + node_handlers = app.node_handlers() + + def default_node_handler( + handle: NodeHandleCB, chain: typing.Sequence[QName], node: Node + ) -> tuple[int, bool]: + return handle_simple_node(handle, chain, node) + + nstatic_cache: dict[QName, int] = {} + + def node_handle_cb( + chain: typing.Sequence[QName], missing_ok: bool = False + ) -> tuple[int, bool]: + nonlocal nstatic_cache + assert len(chain) > 0 + + if len(chain) > cfg_max_call_depth: + raise ValueError(f"max call depth exceeded: {chain}") + + call_orig_qname = chain[-1] + call_qname = graphdata.resolve_funcname(call_orig_qname) + + def dbglog(msg: str) -> None: + if dbg_nstatic: + print( + f"//dbg-nstatic: {'- '*(len(chain)-1)}{call_qname or call_orig_qname}\t{msg}" + ) + + if not call_qname: + if not missing_ok: + missing.add(call_orig_qname) + dbglog("missing") + nstatic_cache[call_orig_qname] = 0 + return 0, True + + assert call_qname in graphdata.graph + if (not dbg_nocache) and call_qname in nstatic_cache: + nstatic = nstatic_cache[call_qname] + dbglog(f"total={nstatic} (cache-read)") + return nstatic, True + node = graphdata.graph[call_qname] + dbglog(str(node.nstatic)) if node.usage_kind == "dynamic" or node.ndynamic > 0: - dynamic.add(funcname) + dynamic.add(call_qname) if track_inclusion: - included_funcs.add(funcname) - - max_call_nstatic = 0 - max_call_nchain = 0 - - if node.calls: - skipmodel = skipmodels.get(funcname.base()) - chain.append(funcname) - if len(chain) == cfg_max_call_depth: - raise ValueError(f"max call depth exceeded: {chain}") - for call_orig_qname, call_missing_ok in node.calls.items(): - skip_nchain = 0 - # 1. Resolve - call_qname = graphdata.resolve_funcname(call_orig_qname) - if not call_qname: - if skipmodel: - skip, _ = skipmodel(chain[:-1], node, call_orig_qname) - if skip: - if dbg_nstatic: - putdbg(f"{call_orig_qname}\tskip missing") - continue - if not call_missing_ok: - missing.add(call_orig_qname) - if dbg_nstatic: - putdbg(f"{call_orig_qname}\tmissing") - continue - - # 2. Skip - if skipmodel: - skip, skip_nchain = skipmodel(chain[:-1], node, call_qname) - max_call_nchain = max(max_call_nchain, skip_nchain) - if skip: - if dbg_nstatic: - putdbg(f"{call_qname}\tskip") - continue - - # 3. Call - if ( - (not dbg_nocache) - and skip_nchain == 0 - and call_qname in _nstatic_cache - ): - call_nstatic = _nstatic_cache[call_qname] - if dbg_nstatic: - putdbg(f"{call_qname}\ttotal={call_nstatic} (cache-read)") - max_call_nstatic = max(max_call_nstatic, call_nstatic) - else: - call_nstatic, call_nchain = _nstatic(chain, call_qname) - max_call_nstatic = max(max_call_nstatic, call_nstatic) - max_call_nchain = max(max_call_nchain, call_nchain) - if skip_nchain == 0 and call_nchain == 0: - if dbg_nstatic: - putdbg(f"{call_qname}\ttotal={call_nstatic} (cache-write)") - if call_qname not in _nstatic_cache: - if dbg_cache: - print(f"//dbg-cache: {call_qname} = {call_nstatic}") - _nstatic_cache[call_qname] = call_nstatic - else: - assert dbg_nocache - assert _nstatic_cache[call_qname] == call_nstatic - elif dbg_nstatic: - putdbg(f"{call_qname}\ttotal={call_nstatic} (do-not-cache)") - chain.pop() - return node.nstatic + max_call_nstatic, max(0, max_call_nchain - 1) + included_funcs.add(call_qname) + + handler = node_handlers.get(call_qname.base(), default_node_handler) + nstatic, cacheable = handler(node_handle_cb, chain[:-1], node) + if cacheable: + dbglog(f"total={nstatic} (cache-write)") + if call_qname not in nstatic_cache: + if dbg_cache: + print(f"//dbg-cache: {call_qname} = {nstatic}") + nstatic_cache[call_qname] = nstatic + else: + assert dbg_nocache + assert nstatic_cache[call_qname] == nstatic + else: + dbglog(f"total={nstatic} (do-not-cache)") + return nstatic, cacheable def nstatic(funcname: QName) -> int: - return _nstatic([], funcname)[0] + return node_handle_cb([funcname])[0] groups: dict[str, AnalyzeResultGroup] = {} for grp_name, grp_filter in app_func_filters.items(): |