diff options
Diffstat (limited to 'build-aux/measurestack/testutil.py')
-rw-r--r-- | build-aux/measurestack/testutil.py | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/build-aux/measurestack/testutil.py b/build-aux/measurestack/testutil.py new file mode 100644 index 0000000..751e57f --- /dev/null +++ b/build-aux/measurestack/testutil.py @@ -0,0 +1,131 @@ +# build-aux/measurestack/testutil.py - Utilities for writing tests +# +# Copyright (C) 2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import typing + +from . import analyze, util + +# pylint: disable=unused-variable +__all__ = [ + "aprime_gen", + "aprime_decompose", + "NopPlugin", + "GraphProviderPlugin", + "nop_location_xform", +] + + +def aprime_gen(l: int, n: int) -> typing.Sequence[int]: + """Return an `l`-length sequence of nonnegative + integers such that any `n`-length-or-shorter combination of + members with repeats allowed can be uniquely identified by its + sum. + + (If that were "product" instead of "sum", the obvious solution + would be the first `l` primes.) + + """ + seq = [1] + while len(seq) < l: + x = seq[-1] * n + 1 + seq.append(x) + return seq + + +def aprime_decompose( + aprimes: typing.Sequence[int], tot: int +) -> tuple[typing.Collection[int], typing.Collection[int]]: + ret_idx = [] + ret_val = [] + while tot: + idx = max(i for i in range(len(aprimes)) if aprimes[i] <= tot) + val = aprimes[idx] + ret_idx.append(idx) + ret_val.append(val) + tot -= val + return ret_idx, ret_val + + +class NopPlugin: + def is_intrhandler(self, name: analyze.QName) -> bool: + return False + + def init_array(self) -> typing.Collection[analyze.QName]: + return [] + + def extra_includes(self) -> typing.Collection[analyze.BaseName]: + return [] + + def indirect_callees( + self, loc: str, line: str + ) -> tuple[typing.Collection[analyze.QName], bool] | None: + return None + + def skipmodels(self) -> dict[analyze.BaseName, analyze.SkipModel]: + return {} + + def extra_nodes(self) -> typing.Collection[analyze.Node]: + return [] + + +class GraphProviderPlugin(NopPlugin): + _nodes: typing.Sequence[analyze.Node] + + def __init__( + self, + max_call_depth: int, + graph: typing.Sequence[tuple[str, typing.Collection[str]]], + ) -> None: + seq = aprime_gen(len(graph), max_call_depth) + nodes: list[analyze.Node] = [] + for i, (name, calls) in enumerate(graph): + nodes.append(util.synthetic_node(name, seq[i], calls)) + assert ( + len(graph) + == len(nodes) + == len(set(n.nstatic for n in nodes)) + == len(set(str(n.funcname.base()) for n in nodes)) + ) + self._nodes = nodes + + def extra_nodes(self) -> typing.Collection[analyze.Node]: + return self._nodes + + def decode_nstatic(self, tot: int) -> typing.Collection[str]: + idxs, _ = aprime_decompose([n.nstatic for n in self._nodes], tot) + return [str(self._nodes[i].funcname.base()) for i in idxs] + + def encode_nstatic(self, calls: typing.Collection[str]) -> int: + tot = 0 + d: dict[str, int] = {} + for node in self._nodes: + d[str(node.funcname.base())] = node.nstatic + print(d) + for call in calls: + tot += d[call] + return tot + + def sorted_calls(self, calls: typing.Collection[str]) -> typing.Sequence[str]: + d: dict[str, int] = {} + for node in self._nodes: + d[str(node.funcname.base())] = node.nstatic + + def k(call: str) -> int: + return d[call] + + return sorted(calls, key=k) + + def assert_nstatic(self, act_tot: int, exp_calls: typing.Collection[str]) -> None: + exp_tot = self.encode_nstatic(exp_calls) + if act_tot != exp_tot: + act_str = f"{act_tot}: {self.sorted_calls(self.decode_nstatic(act_tot))}" + exp_str = f"{exp_tot}: {self.sorted_calls(exp_calls)}" + assert ( + False + ), f"act:{act_tot} != exp:{exp_tot}\n\t-exp = {exp_str}\n\t+act = {act_str}" + + +def nop_location_xform(loc: str) -> str: + return loc |